From 3ee46631ceca672db0bd5238f3f59b096e3bf16f Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 3 Jul 2025 23:26:18 +0000 Subject: [PATCH 01/38] Added basic bool approx class Signed-off-by: Sawan Srivastava Implemented canApproximate and rewrite for single clause bool queries in ApproximateBooleanQuery Signed-off-by: Sawan Srivastava Added ApproximateScoreQuery wrapping in BoolQueryBuilder Signed-off-by: Sawan Srivastava Enabled approximation in single clause bool queries by calling setContext in ApproximateBooleanQuery Signed-off-by: Sawan Srivastava Removed redundancy by adding pattern matching to ApproximateScoreQuery check Signed-off-by: Sawan Srivastava setContext in canApproximate to remove redundant context variable Signed-off-by: Sawan Srivastava fix rewrite method to default to original query for multi-clause case Signed-off-by: Sawan Srivastava Prevent multi-clause bool queries from using ApproximateBooleanQuery (for now) Signed-off-by: Sawan Srivastava Fix failing tests for single clause boolean queries Signed-off-by: Sawan Srivastava fix nested single clause boolean queries Signed-off-by: Sawan Srivastava Enabled proper recursive rewriting to ensure clauses are properly rewritten Signed-off-by: Sawan Srivastava Unwrap boolean query in setContext Signed-off-by: Sawan Srivastava Removed redundant unwrap methods Signed-off-by: Sawan Srivastava Fixed more integration tests Signed-off-by: Sawan Srivastava Actually check whether nested query can be approximated Signed-off-by: Sawan Srivastava --- .../index/query/BoolQueryBuilder.java | 12 +- .../approximate/ApproximateBooleanQuery.java | 115 ++++++++++++++++++ .../approximate/ApproximateScoreQuery.java | 7 +- .../index/query/BoolQueryBuilderTests.java | 19 ++- .../query/MultiMatchQueryBuilderTests.java | 3 + .../query/QueryStringQueryBuilderTests.java | 2 + .../query/SimpleQueryStringBuilderTests.java | 3 + 7 files changed, 153 insertions(+), 8 deletions(-) create mode 100644 
server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java diff --git a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java index f2e7565c885c1..bb4a345374917 100644 --- a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java @@ -45,6 +45,8 @@ import org.opensearch.core.xcontent.ObjectParser; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.approximate.ApproximateBooleanQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import java.io.IOException; import java.util.ArrayList; @@ -335,7 +337,15 @@ protected Query doToQuery(QueryShardContext context) throws IOException { } Query query = Queries.applyMinimumShouldMatch(booleanQuery, minimumShouldMatch); - return adjustPureNegative ? fixNegativeQueryIfNeeded(query) : query; + + if (adjustPureNegative) { + query = fixNegativeQueryIfNeeded(query); + } + + // TODO: Figure out why multi-clause breaks testPhrasePrefix() in HighlighterWithAnalyzersTests.java + return ((BooleanQuery) query).clauses().size() == 1 + ? new ApproximateScoreQuery(query, new ApproximateBooleanQuery((BooleanQuery) query)) + : query; } private static void addBooleanClauses( diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java new file mode 100644 index 0000000000000..ba86d0dd0eee1 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.opensearch.search.internal.SearchContext; + +import java.util.List; + +/** + * An approximate-able version of {@link BooleanQuery}. For single clause boolean queries, + * it unwraps the query into the singular clause and ensures approximation is applied. + */ +public class ApproximateBooleanQuery extends ApproximateQuery { + public final BooleanQuery boolQuery; + private final int size; + private final List clauses; + private ApproximateBooleanQuery booleanQuery; + public boolean isUnwrapped = false; + + public ApproximateBooleanQuery(BooleanQuery boolQuery) { + this(boolQuery, SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO); + } + + protected ApproximateBooleanQuery(BooleanQuery boolQuery, int size) { + this.boolQuery = boolQuery; + this.size = size; + this.clauses = boolQuery.clauses(); + } + + public ApproximateBooleanQuery getBooleanQuery() { + return booleanQuery; + } + + public Query getClauseQuery() { + return clauses.get(0).query(); + } + + public static Query unwrap(Query unwrapBoolQuery) { + Query clauseQuery = unwrapBoolQuery instanceof ApproximateBooleanQuery + ? 
((ApproximateBooleanQuery) unwrapBoolQuery).getClauseQuery() + : ((BooleanQuery) unwrapBoolQuery).clauses().get(0).query(); + if (clauseQuery instanceof ApproximateBooleanQuery nestedBool) { + return unwrap(nestedBool); + } else { + return clauseQuery; + } + } + + @Override + protected boolean canApproximate(SearchContext context) { + booleanQuery = this; + if (context == null) { + return false; + } + + // Don't approximate if we need accurate total hits + if (context.trackTotalHitsUpTo() == SearchContext.TRACK_TOTAL_HITS_ACCURATE) { + return false; + } + + // Don't approximate if we have aggregations + if (context.aggregations() != null) { + return false; + } + + // For single clause boolean queries, check if the clause can be approximated + if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { + BooleanClause singleClause = clauses.get(0); + Query clauseQuery = singleClause.query(); + + // If the clause is already an ApproximateScoreQuery, we can approximate + set context + if (clauseQuery instanceof ApproximateScoreQuery approximateScoreQuery) { + if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { + return nestedBool.canApproximate(context); + } + return approximateScoreQuery.getApproximationQuery().canApproximate(context); + } + } + + return false; + } + + @Override + public String toString(String s) { + return "ApproximateBooleanQuery(" + boolQuery.toString(s) + ")"; + } + + @Override + public void visit(QueryVisitor queryVisitor) { + boolQuery.visit(queryVisitor); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ApproximateBooleanQuery that = (ApproximateBooleanQuery) o; + return size == that.size && boolQuery.equals(that.boolQuery); + } + + @Override + public int hashCode() { + return boolQuery.hashCode(); + } +} diff --git 
a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index be1b6eed5333d..eb66e2a3779ec 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -9,6 +9,7 @@ package org.opensearch.search.approximate; import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -50,6 +51,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { } Query rewritten = resolvedQuery.rewrite(indexSearcher); if (rewritten != resolvedQuery) { + // To make sure that query goes through entire rewrite process resolvedQuery = rewritten; } return this; @@ -57,7 +59,10 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { public void setContext(SearchContext context) { resolvedQuery = approximationQuery.canApproximate(context) ? 
approximationQuery : originalQuery; - }; + if (resolvedQuery instanceof ApproximateBooleanQuery || resolvedQuery instanceof BooleanQuery) { + resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); + } + } @Override public String toString(String s) { diff --git a/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java index 85e1d0f00c661..26b61880df929 100644 --- a/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/BoolQueryBuilderTests.java @@ -50,6 +50,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.approximate.ApproximateBooleanQuery; import org.opensearch.search.approximate.ApproximateMatchAllQuery; import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.search.internal.ContextIndexSearcher; @@ -119,8 +120,14 @@ protected void doAssertLuceneQuery(BoolQueryBuilder queryBuilder, Query query, Q assertThat(query, instanceOf(ApproximateScoreQuery.class)); assertThat(((ApproximateScoreQuery) query).getOriginalQuery(), instanceOf(MatchAllDocsQuery.class)); } else if (query instanceof MatchNoDocsQuery == false) { - assertThat(query, instanceOf(BooleanQuery.class)); - BooleanQuery booleanQuery = (BooleanQuery) query; + BooleanQuery booleanQuery; + if (query instanceof ApproximateScoreQuery) { // true for single clause cases + assertThat(((ApproximateScoreQuery) query).getOriginalQuery(), instanceOf(BooleanQuery.class)); + booleanQuery = (BooleanQuery) ((ApproximateScoreQuery) query).getOriginalQuery(); + } else { + assertThat(query, instanceOf(BooleanQuery.class)); + booleanQuery = (BooleanQuery) query; + } if (queryBuilder.adjustPureNegative()) { boolean isNegative = true; for (BooleanClause clause : clauses) { 
@@ -210,14 +217,14 @@ public void testMinShouldMatchFilterWithoutShouldClauses() throws Exception { BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.filter(new BoolQueryBuilder().must(new MatchAllQueryBuilder())); Query query = boolQueryBuilder.toQuery(createShardContext()); - assertThat(query, instanceOf(BooleanQuery.class)); - BooleanQuery booleanQuery = (BooleanQuery) query; + assertThat(((ApproximateScoreQuery) query).getApproximationQuery(), instanceOf(ApproximateBooleanQuery.class)); + BooleanQuery booleanQuery = (BooleanQuery) ((ApproximateScoreQuery) query).getOriginalQuery(); assertThat(booleanQuery.getMinimumNumberShouldMatch(), equalTo(0)); assertThat(booleanQuery.clauses().size(), equalTo(1)); BooleanClause booleanClause = booleanQuery.clauses().get(0); assertThat(booleanClause.occur(), equalTo(BooleanClause.Occur.FILTER)); - assertThat(booleanClause.query(), instanceOf(BooleanQuery.class)); - BooleanQuery innerBooleanQuery = (BooleanQuery) booleanClause.query(); + assertThat(((ApproximateScoreQuery) booleanClause.query()).getOriginalQuery(), instanceOf(BooleanQuery.class)); + BooleanQuery innerBooleanQuery = (BooleanQuery) ((ApproximateScoreQuery) booleanClause.query()).getOriginalQuery(); // we didn't set minimum should match initially, there are no should clauses so it should be 0 assertThat(innerBooleanQuery.getMinimumNumberShouldMatch(), equalTo(0)); assertThat(innerBooleanQuery.clauses().size(), equalTo(1)); diff --git a/server/src/test/java/org/opensearch/index/query/MultiMatchQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/MultiMatchQueryBuilderTests.java index d352d54b6f02a..ce7fdaf84f8ac 100644 --- a/server/src/test/java/org/opensearch/index/query/MultiMatchQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/MultiMatchQueryBuilderTests.java @@ -55,6 +55,7 @@ import org.opensearch.index.query.MultiMatchQueryBuilder.Type; import org.opensearch.index.search.MatchQuery; 
import org.opensearch.lucene.queries.ExtendedCommonTermsQuery; +import org.opensearch.search.approximate.ApproximateBooleanQuery; import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.AbstractQueryTestCase; @@ -562,6 +563,7 @@ public void testWithStopWords() throws Exception { query = new BoolQueryBuilder().should(new MultiMatchQueryBuilder("the").field(TEXT_FIELD_NAME).analyzer("stop")) .toQuery(createShardContext()); expected = new BooleanQuery.Builder().add(new MatchNoDocsQuery(), BooleanClause.Occur.SHOULD).build(); + expected = new ApproximateScoreQuery(expected, new ApproximateBooleanQuery((BooleanQuery) expected)); assertEquals(expected, query); query = new BoolQueryBuilder().should( @@ -571,6 +573,7 @@ public void testWithStopWords() throws Exception { new DisjunctionMaxQuery(Arrays.asList(new MatchNoDocsQuery(), new MatchNoDocsQuery()), 0f), BooleanClause.Occur.SHOULD ).build(); + expected = new ApproximateScoreQuery(expected, new ApproximateBooleanQuery((BooleanQuery) expected)); assertEquals(expected, query); } diff --git a/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java b/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java index ea31d2680d4ec..6a4034ee0ae40 100644 --- a/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/QueryStringQueryBuilderTests.java @@ -77,6 +77,7 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.search.QueryStringQueryParser; import org.opensearch.lucene.queries.BlendedTermQuery; +import org.opensearch.search.approximate.ApproximateBooleanQuery; import org.opensearch.search.approximate.ApproximatePointRangeQuery; import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.AbstractQueryTestCase; @@ -1454,6 +1455,7 @@ public void testWithStopWords() throws Exception { query = new 
BoolQueryBuilder().should(new QueryStringQueryBuilder("the").field(TEXT_FIELD_NAME).analyzer("stop")) .toQuery(createShardContext()); expected = new BooleanQuery.Builder().add(new BooleanQuery.Builder().build(), BooleanClause.Occur.SHOULD).build(); + expected = new ApproximateScoreQuery(expected, new ApproximateBooleanQuery((BooleanQuery) expected)); assertEquals(expected, query); query = new BoolQueryBuilder().should( diff --git a/server/src/test/java/org/opensearch/index/query/SimpleQueryStringBuilderTests.java b/server/src/test/java/org/opensearch/index/query/SimpleQueryStringBuilderTests.java index 0edd387ea9c6f..48080f9a13363 100644 --- a/server/src/test/java/org/opensearch/index/query/SimpleQueryStringBuilderTests.java +++ b/server/src/test/java/org/opensearch/index/query/SimpleQueryStringBuilderTests.java @@ -56,6 +56,8 @@ import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.common.settings.Settings; import org.opensearch.index.search.SimpleQueryStringQueryParser; +import org.opensearch.search.approximate.ApproximateBooleanQuery; +import org.opensearch.search.approximate.ApproximateScoreQuery; import org.opensearch.test.AbstractQueryTestCase; import java.io.IOException; @@ -775,6 +777,7 @@ public void testWithStopWords() throws Exception { query = new BoolQueryBuilder().should(new SimpleQueryStringBuilder("the").field(TEXT_FIELD_NAME).analyzer("stop")) .toQuery(createShardContext()); expected = new BooleanQuery.Builder().add(new MatchNoDocsQuery(), BooleanClause.Occur.SHOULD).build(); + expected = new ApproximateScoreQuery(expected, new ApproximateBooleanQuery((BooleanQuery) expected)); assertEquals(expected, query); query = new BoolQueryBuilder().should( From e0621f30e4f0f2a38749e909cb2dfb32577e763e Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 11 Jul 2025 05:23:32 +0000 Subject: [PATCH 02/38] Ensure setcontext is called on query Signed-off-by: Sawan Srivastava --- .../search/approximate/ApproximateScoreQuery.java | 8 
++++++++ 1 file changed, 8 insertions(+) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index eb66e2a3779ec..df5b2ec8ffaab 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -61,6 +61,14 @@ public void setContext(SearchContext context) { resolvedQuery = approximationQuery.canApproximate(context) ? approximationQuery : originalQuery; if (resolvedQuery instanceof ApproximateBooleanQuery || resolvedQuery instanceof BooleanQuery) { resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); + if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { + appxResolved.setContext(context); + } + try { + resolvedQuery = resolvedQuery.rewrite(context.searcher()); + } catch (IOException e) { + throw new RuntimeException(e); + } } } From d2f7fe3a080d1873cce9230b0e6b6e7981d1dc11 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 17 Jul 2025 21:59:02 +0000 Subject: [PATCH 03/38] implement createWeight in ApproximateBooleanQuery Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 40 +++++++++++++++++-- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index ba86d0dd0eee1..84aca430c7ec5 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -8,12 +8,11 @@ package org.opensearch.search.approximate; -import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.Query; -import 
org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.*; import org.opensearch.search.internal.SearchContext; +import java.io.IOException; import java.util.List; /** @@ -90,6 +89,39 @@ protected boolean canApproximate(SearchContext context) { return false; } + @Override + public ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost){ + return new ConstantScoreWeight(this, boost) { + + /** + * @param ctx + * @return {@code true} if the object can be cached against a given leaf + */ + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + + /** + * Get a {@link ScorerSupplier}, which allows knowing the cost of the {@link Scorer} before + * building it. A scorer supplier for the same {@link LeafReaderContext} instance may be requested + * multiple times as part of a single search call. + * + *

Note: It must return null if the scorer is null. + * + * @param context the leaf reader context + * @return a {@link ScorerSupplier} providing the scorer, or null if scorer is null + * @throws IOException if an IOException occurs + * @see Scorer + * @see DefaultScorerSupplier + */ + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + return null; + } + }; + } + @Override public String toString(String s) { return "ApproximateBooleanQuery(" + boolQuery.toString(s) + ")"; From f58229afb2079be8706d883826dbf9594f5cfa29 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 17 Jul 2025 15:13:25 -0700 Subject: [PATCH 04/38] Create basic createWeight to bulkScorer outline Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 110 +++++++++++++----- 1 file changed, 83 insertions(+), 27 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index 84aca430c7ec5..3801c17f0ccad 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -90,36 +90,92 @@ protected boolean canApproximate(SearchContext context) { } @Override - public ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost){ - return new ConstantScoreWeight(this, boost) { - - /** - * @param ctx - * @return {@code true} if the object can be cached against a given leaf - */ - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return false; + public ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + // For single clause boolean queries, delegate to the clause's createWeight + if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { + 
BooleanClause singleClause = clauses.get(0); + Query clauseQuery = singleClause.query(); + + // If it's a scoring query, wrap it in a ConstantScoreQuery to ensure constant scoring + if (!(clauseQuery instanceof ConstantScoreQuery)) { + clauseQuery = new ConstantScoreQuery(clauseQuery); } - - /** - * Get a {@link ScorerSupplier}, which allows knowing the cost of the {@link Scorer} before - * building it. A scorer supplier for the same {@link LeafReaderContext} instance may be requested - * multiple times as part of a single search call. - * - *

Note: It must return null if the scorer is null. - * - * @param context the leaf reader context - * @return a {@link ScorerSupplier} providing the scorer, or null if scorer is null - * @throws IOException if an IOException occurs - * @see Scorer - * @see DefaultScorerSupplier - */ - @Override - public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + + return (ConstantScoreWeight) clauseQuery.createWeight(searcher, scoreMode, boost); + } + + // For multi-clause boolean queries, create a custom weight + return new ApproximateBooleanWeight(searcher, scoreMode, boost); + } + + /** + * Custom Weight implementation for ApproximateBooleanQuery that handles multi-clause boolean queries. + * This is a basic implementation that behaves like a regular filter boolean query for now. + */ + private class ApproximateBooleanWeight extends ConstantScoreWeight { + private final Weight booleanWeight; + + public ApproximateBooleanWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + super(ApproximateBooleanQuery.this, boost); + // Create a weight for the underlying boolean query + this.booleanWeight = boolQuery.createWeight(searcher, scoreMode, boost); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return false; + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + // Get the scorer supplier from the underlying boolean weight + final ScorerSupplier booleanScorer = booleanWeight.scorerSupplier(context); + if (booleanScorer == null) { return null; } - }; + + // Return a wrapper scorer supplier that delegates to the boolean scorer + return new ScorerSupplier() { + @Override + public Scorer get(long leadCost) throws IOException { + Scorer scorer = booleanScorer.get(leadCost); + if (scorer == null) { + return null; + } + return new ConstantScoreScorer(ApproximateBooleanWeight.this, score(), scoreMode, scorer.iterator()); + } + + @Override + public 
long cost() { + return booleanScorer.cost(); + } + + @Override + public BulkScorer bulkScorer() throws IOException { + // For now, just delegate to the standard bulk scorer + // In the future, this is where we would implement our custom bulk scorer + return booleanScorer.bulkScorer(); + } + }; + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + ScorerSupplier scorerSupplier = scorerSupplier(context); + if (scorerSupplier == null) { + return null; + } + return scorerSupplier.get(Long.MAX_VALUE); + } + + @Override + public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { + ScorerSupplier scorerSupplier = scorerSupplier(context); + if (scorerSupplier == null) { + return null; + } + return scorerSupplier.bulkScorer(); + } } @Override From 30013e4b013fd7407943d6c2614507199d69c058 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 17 Jul 2025 23:31:39 +0000 Subject: [PATCH 05/38] custom weight implementation using default bulkscorer Signed-off-by: Sawan Srivastava --- .../index/query/BoolQueryBuilder.java | 7 +- .../approximate/ApproximateBooleanQuery.java | 66 +++++++++---------- .../approximate/ApproximateScoreQuery.java | 3 +- 3 files changed, 39 insertions(+), 37 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java index bb4a345374917..a9564db7b2c1d 100644 --- a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java @@ -343,9 +343,10 @@ protected Query doToQuery(QueryShardContext context) throws IOException { } // TODO: Figure out why multi-clause breaks testPhrasePrefix() in HighlighterWithAnalyzersTests.java - return ((BooleanQuery) query).clauses().size() == 1 - ? 
new ApproximateScoreQuery(query, new ApproximateBooleanQuery((BooleanQuery) query)) - : query; + // return ((BooleanQuery) query).clauses().size() == 1 + // ? new ApproximateScoreQuery(query, new ApproximateBooleanQuery((BooleanQuery) query)) + // : query; + return new ApproximateScoreQuery(query, new ApproximateBooleanQuery((BooleanQuery) query)); } private static void addBooleanClauses( diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index 3801c17f0ccad..4638e5257cd3a 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -9,7 +9,19 @@ package org.opensearch.search.approximate; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.*; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.ConstantScoreQuery; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; import org.opensearch.search.internal.SearchContext; import java.io.IOException; @@ -24,7 +36,6 @@ public class ApproximateBooleanQuery extends ApproximateQuery { private final int size; private final List clauses; private ApproximateBooleanQuery booleanQuery; - public boolean isUnwrapped = false; public ApproximateBooleanQuery(BooleanQuery boolQuery) { this(boolQuery, SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO); @@ -86,6 +97,11 @@ 
protected boolean canApproximate(SearchContext context) { } } + if (clauses.size() > 1) { + // need to update + return true; + } + return false; } @@ -93,39 +109,40 @@ protected boolean canApproximate(SearchContext context) { public ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { // For single clause boolean queries, delegate to the clause's createWeight if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { - BooleanClause singleClause = clauses.get(0); - Query clauseQuery = singleClause.query(); - + Query clauseQuery = clauses.get(0).query(); + // If it's a scoring query, wrap it in a ConstantScoreQuery to ensure constant scoring if (!(clauseQuery instanceof ConstantScoreQuery)) { clauseQuery = new ConstantScoreQuery(clauseQuery); } - + return (ConstantScoreWeight) clauseQuery.createWeight(searcher, scoreMode, boost); } - + // For multi-clause boolean queries, create a custom weight return new ApproximateBooleanWeight(searcher, scoreMode, boost); } - + /** * Custom Weight implementation for ApproximateBooleanQuery that handles multi-clause boolean queries. * This is a basic implementation that behaves like a regular filter boolean query for now. 
*/ private class ApproximateBooleanWeight extends ConstantScoreWeight { private final Weight booleanWeight; - + private final ScoreMode scoreMode; + public ApproximateBooleanWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(ApproximateBooleanQuery.this, boost); // Create a weight for the underlying boolean query this.booleanWeight = boolQuery.createWeight(searcher, scoreMode, boost); + this.scoreMode = scoreMode; } - + @Override public boolean isCacheable(LeafReaderContext ctx) { return false; } - + @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { // Get the scorer supplier from the underlying boolean weight @@ -133,7 +150,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (booleanScorer == null) { return null; } - + // Return a wrapper scorer supplier that delegates to the boolean scorer return new ScorerSupplier() { @Override @@ -142,14 +159,14 @@ public Scorer get(long leadCost) throws IOException { if (scorer == null) { return null; } - return new ConstantScoreScorer(ApproximateBooleanWeight.this, score(), scoreMode, scorer.iterator()); + return new ConstantScoreScorer(score(), scoreMode, scorer.iterator()); } - + @Override public long cost() { return booleanScorer.cost(); } - + @Override public BulkScorer bulkScorer() throws IOException { // For now, just delegate to the standard bulk scorer @@ -158,24 +175,7 @@ public BulkScorer bulkScorer() throws IOException { } }; } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - ScorerSupplier scorerSupplier = scorerSupplier(context); - if (scorerSupplier == null) { - return null; - } - return scorerSupplier.get(Long.MAX_VALUE); - } - - @Override - public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { - ScorerSupplier scorerSupplier = scorerSupplier(context); - if (scorerSupplier == null) { - return null; - } - return 
scorerSupplier.bulkScorer(); - } + } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index df5b2ec8ffaab..c597e56a9c372 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -59,7 +59,8 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { public void setContext(SearchContext context) { resolvedQuery = approximationQuery.canApproximate(context) ? approximationQuery : originalQuery; - if (resolvedQuery instanceof ApproximateBooleanQuery || resolvedQuery instanceof BooleanQuery) { + if (resolvedQuery instanceof ApproximateBooleanQuery && ((BooleanQuery) originalQuery).clauses().size() == 1 + || resolvedQuery instanceof BooleanQuery) { resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { appxResolved.setContext(context); From cc015702b682b22fe3802e3d067cf5222e4cdb10 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 18 Jul 2025 17:03:38 +0000 Subject: [PATCH 06/38] updated canApproximate Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 13 +++--------- .../approximate/ApproximateScoreQuery.java | 20 +++++++++++++++++-- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index 4638e5257cd3a..bf2aa1817b2dd 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -35,7 +35,6 @@ public class ApproximateBooleanQuery extends ApproximateQuery { public final 
BooleanQuery boolQuery; private final int size; private final List clauses; - private ApproximateBooleanQuery booleanQuery; public ApproximateBooleanQuery(BooleanQuery boolQuery) { this(boolQuery, SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO); @@ -47,8 +46,8 @@ protected ApproximateBooleanQuery(BooleanQuery boolQuery, int size) { this.clauses = boolQuery.clauses(); } - public ApproximateBooleanQuery getBooleanQuery() { - return booleanQuery; + public BooleanQuery getBooleanQuery() { + return boolQuery; } public Query getClauseQuery() { @@ -68,7 +67,6 @@ public static Query unwrap(Query unwrapBoolQuery) { @Override protected boolean canApproximate(SearchContext context) { - booleanQuery = this; if (context == null) { return false; } @@ -97,12 +95,7 @@ protected boolean canApproximate(SearchContext context) { } } - if (clauses.size() > 1) { - // need to update - return true; - } - - return false; + return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index c597e56a9c372..352cd3408aa67 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -59,12 +59,28 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { public void setContext(SearchContext context) { resolvedQuery = approximationQuery.canApproximate(context) ? 
approximationQuery : originalQuery; - if (resolvedQuery instanceof ApproximateBooleanQuery && ((BooleanQuery) originalQuery).clauses().size() == 1 - || resolvedQuery instanceof BooleanQuery) { + + boolean needsRewrite = false; + + if (resolvedQuery instanceof ApproximateBooleanQuery appxBool) { + if (appxBool.getBooleanQuery().clauses().size() == 1) { + // For single-clause boolean queries, unwrap and process as before + resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); + if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { + appxResolved.setContext(context); + } + } + needsRewrite = true; + } else if (resolvedQuery instanceof BooleanQuery) { resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { appxResolved.setContext(context); } + needsRewrite = true; + } + + // Only rewrite boolean queries + if (needsRewrite) { try { resolvedQuery = resolvedQuery.rewrite(context.searcher()); } catch (IOException e) { From e2f69ba2cb8fad7a55abdc413d26c1eb21993d00 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 18 Jul 2025 19:14:02 +0000 Subject: [PATCH 07/38] created basic skeletons of custom classes Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 12 ++ .../ApproximateBooleanScorerSupplier.java | 53 ++++++ .../ApproximateConjunctionDISI.java | 23 +++ .../ApproximateConjunctionScorer.java | 58 ++++++ .../search/approximate/ResumableDISI.java | 169 ++++++++++++++++++ 5 files changed, 315 insertions(+) create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java create mode 100644 server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java diff --git 
a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index bf2aa1817b2dd..dca6118065d5a 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -136,6 +136,17 @@ public boolean isCacheable(LeafReaderContext ctx) { return false; } + // public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { + // ScorerSupplier scorerSupplier = scorerSupplier(context); + // if (scorerSupplier == null) { + // // No docs match + // return null; + // } + // + // scorerSupplier.setTopLevelScoringClause(); + // return scorerSupplier.bulkScorer(); + // } + @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { // Get the scorer supplier from the underlying boolean weight @@ -144,6 +155,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return null; } + // return new ApproximateBooleanScorerSupplier(); // Return a wrapper scorer supplier that delegates to the boolean scorer return new ScorerSupplier() { @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java new file mode 100644 index 0000000000000..cb6e3ac8ae59d --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -0,0 +1,53 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; + +import java.io.IOException; + +public class ApproximateBooleanScorerSupplier extends ScorerSupplier { + + /** + * Get the {@link Scorer}. This may not return {@code null} and must be called at most once. + * + * @param leadCost Cost of the scorer that will be used in order to lead iteration. This can be + * interpreted as an upper bound of the number of times that {@link DocIdSetIterator#nextDoc}, + * {@link DocIdSetIterator#advance} and TwoPhaseIterator#matches will be called. Under + * doubt, pass {@link Long#MAX_VALUE}, which will produce a {@link Scorer} that has good + * iteration capabilities. + */ + @Override + public Scorer get(long leadCost) throws IOException { + return null; + } + + /** + * Optional method: Get a scorer that is optimized for bulk-scoring. The default implementation + * iterates matches from the {@link Scorer} but some queries can have more efficient approaches + * for matching all hits. + */ + public BulkScorer bulkScorer() throws IOException { + return null; + } + + /** + * Get an estimate of the {@link Scorer} that would be returned by {@link #get}. This may be a + * costly operation, so it should only be called if necessary. 
+ * + * @see DocIdSetIterator#cost + */ + @Override + public long cost() { + return 0; + } +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java new file mode 100644 index 0000000000000..6b5fef5876af2 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -0,0 +1,23 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.FilterDocIdSetIterator; + +public class ApproximateConjunctionDISI extends FilterDocIdSetIterator { + /** + * Sole constructor. + * + * @param in + */ + public ApproximateConjunctionDISI(DocIdSetIterator in) { + super(in); + } +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java new file mode 100644 index 0000000000000..4bcad6a58c1ea --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; + +public class ApproximateConjunctionScorer extends Scorer { + /** + * Returns the doc ID that is currently being scored. 
+ */ + @Override + public int docID() { + return 0; + } + + /** + * Return a {@link DocIdSetIterator} over matching documents. + * + *

The returned iterator will either be positioned on {@code -1} if no documents have been + * scored yet, {@link DocIdSetIterator#NO_MORE_DOCS} if all documents have been scored already, or + * the last document id that has been scored otherwise. + * + *

The returned iterator is a view: calling this method several times will return iterators + * that have the same state. + */ + @Override + public DocIdSetIterator iterator() { + return null; + } + + /** + * Return the maximum score that documents between the last {@code target} that this iterator was + * {@link #advanceShallow(int) shallow-advanced} to included and {@code upTo} included. + * + * @param upTo + */ + @Override + public float getMaxScore(int upTo) throws IOException { + return 0; + } + + /** + * Returns the score of the current document matching the query. + */ + @Override + public float score() throws IOException { + return 0; + } +} diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java new file mode 100644 index 0000000000000..1e28d740eecff --- /dev/null +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -0,0 +1,169 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; + +import java.io.IOException; + +/** + * A resumable DocIdSetIterator that can be used to score documents in batches. + * This class wraps a ScorerSupplier and creates a new Scorer/DocIdSetIterator only when needed. + * It keeps track of the last document scored and can resume from that point when asked to score more documents. 
+ */ +public class ResumableDISI extends DocIdSetIterator { + private static final int DEFAULT_BATCH_SIZE = 10_000; + + private final ScorerSupplier scorerSupplier; + private DocIdSetIterator currentDisi; + private int lastDocID = -1; + private int docsScored = 0; + private final int batchSize; + private boolean exhausted = false; + private long leadCost; + + /** + * Creates a new ResumableDISI with the default batch size of 10,000 documents. + * + * @param scorerSupplier The scorer supplier to get scorers from + */ + public ResumableDISI(ScorerSupplier scorerSupplier) { + this(scorerSupplier, DEFAULT_BATCH_SIZE); + } + + /** + * Creates a new ResumableDISI with the specified batch size. + * + * @param scorerSupplier The scorer supplier to get scorers from + * @param batchSize The number of documents to score in each batch + */ + public ResumableDISI(ScorerSupplier scorerSupplier, int batchSize) { + this.scorerSupplier = scorerSupplier; + this.batchSize = batchSize; + this.leadCost = Long.MAX_VALUE; // Start with max cost, will be adjusted later + } + + /** + * Initializes or resets the internal DocIdSetIterator. + * If this is the first call or we've reached the batch limit, a new DISI is created. + * Otherwise, the existing DISI is reused. 
+ * + * @return The current DocIdSetIterator + * @throws IOException If there's an error getting the scorer + */ + private DocIdSetIterator getOrCreateDisi() throws IOException { + if (exhausted) { + return currentDisi; // Already exhausted, no need to create a new one + } + + if (currentDisi == null || docsScored >= batchSize) { + // If we've already scored some documents, adjust the lead cost + if (docsScored > 0) { + // Reduce the lead cost based on what we've already processed + leadCost = Math.max(scorerSupplier.cost() - docsScored, 0); + } + + // Get a new scorer and its iterator + Scorer scorer = scorerSupplier.get(leadCost); + currentDisi = scorer.iterator(); + + // If we have a last document ID, advance to the next one + if (lastDocID >= 0) { + currentDisi.advance(lastDocID + 1); + } + + // Reset the docs scored counter for this batch + docsScored = 0; + } + + return currentDisi; + } + + @Override + public int docID() { + if (currentDisi == null) { + return -1; + } + return currentDisi.docID(); + } + + @Override + public int nextDoc() throws IOException { + DocIdSetIterator disi = getOrCreateDisi(); + int doc = disi.nextDoc(); + + if (doc != NO_MORE_DOCS) { + lastDocID = doc; + docsScored++; + } else { + exhausted = true; + } + + return doc; + } + + @Override + public int advance(int target) throws IOException { + DocIdSetIterator disi = getOrCreateDisi(); + int doc = disi.advance(target); + + if (doc != NO_MORE_DOCS) { + lastDocID = doc; + docsScored++; + } else { + exhausted = true; + } + + return doc; + } + + @Override + public long cost() { + return scorerSupplier.cost(); + } + + /** + * Resets the iterator to start a new batch from the last document ID. + * This allows the caller to continue scoring from where it left off. + */ + public void resetForNextBatch() { + if (!exhausted) { + currentDisi = null; // Force creation of a new DISI on next call + } + } + + /** + * Returns the number of documents scored in the current batch. 
+ * + * @return The number of documents scored + */ + public int getDocsScored() { + return docsScored; + } + + /** + * Returns whether this iterator has been exhausted. + * + * @return true if there are no more documents to score + */ + public boolean isExhausted() { + return exhausted; + } + + /** + * Returns the last document ID that was scored. + * + * @return The last document ID, or -1 if no documents have been scored + */ + public int getLastDocID() { + return lastDocID; + } +} From 7b8c505433dad7bed8de5c65b55dd469c1aeedb4 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 18 Jul 2025 22:31:06 +0000 Subject: [PATCH 08/38] implemented BKDState in ResumableDISI Signed-off-by: Sawan Srivastava --- .../ApproximatePointRangeQuery.java | 2 + .../search/approximate/ResumableDISI.java | 59 +++++++++++++++---- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 046eb4dc1c86f..35e1982007a4f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -58,6 +58,8 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { public PointRangeQuery pointRangeQuery; private final Function valueToString; + private ResumableDISI.BKDState state; + public ApproximatePointRangeQuery( String field, byte[] lowerPoint, diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index 1e28d740eecff..bf3c2b3cf9327 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -8,6 +8,7 @@ package org.opensearch.search.approximate; 
+import org.apache.lucene.index.PointValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; @@ -17,18 +18,22 @@ /** * A resumable DocIdSetIterator that can be used to score documents in batches. * This class wraps a ScorerSupplier and creates a new Scorer/DocIdSetIterator only when needed. - * It keeps track of the last document scored and can resume from that point when asked to score more documents. + * It maintains state between calls to enable resuming from where it left off. + * + * This implementation is specifically designed for the approximation framework to enable + * early termination while preserving state between scoring cycles. */ public class ResumableDISI extends DocIdSetIterator { private static final int DEFAULT_BATCH_SIZE = 10_000; private final ScorerSupplier scorerSupplier; private DocIdSetIterator currentDisi; - private int lastDocID = -1; - private int docsScored = 0; private final int batchSize; private boolean exhausted = false; - private long leadCost; + + // State tracking + private int lastDocID = -1; + private int docsScored = 0; /** * Creates a new ResumableDISI with the default batch size of 10,000 documents. 
@@ -48,7 +53,6 @@ public ResumableDISI(ScorerSupplier scorerSupplier) { public ResumableDISI(ScorerSupplier scorerSupplier, int batchSize) { this.scorerSupplier = scorerSupplier; this.batchSize = batchSize; - this.leadCost = Long.MAX_VALUE; // Start with max cost, will be adjusted later } /** @@ -65,14 +69,8 @@ private DocIdSetIterator getOrCreateDisi() throws IOException { } if (currentDisi == null || docsScored >= batchSize) { - // If we've already scored some documents, adjust the lead cost - if (docsScored > 0) { - // Reduce the lead cost based on what we've already processed - leadCost = Math.max(scorerSupplier.cost() - docsScored, 0); - } - // Get a new scorer and its iterator - Scorer scorer = scorerSupplier.get(leadCost); + Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); currentDisi = scorer.iterator(); // If we have a last document ID, advance to the next one @@ -166,4 +164,41 @@ public boolean isExhausted() { public int getLastDocID() { return lastDocID; } + + /** + * Class to track the state of BKD tree traversal. 
+ */ + public static class BKDState { + private PointValues.PointTree currentTree; + private boolean isExhausted = false; + private long docCount = 0; + + public PointValues.PointTree getCurrentTree() { + return currentTree; + } + + public void setCurrentTree(PointValues.PointTree tree) { + if (tree != null) { + this.currentTree = tree.clone(); + } else { + this.currentTree = null; + } + } + + public boolean isExhausted() { + return isExhausted; + } + + public void setExhausted(boolean exhausted) { + this.isExhausted = exhausted; + } + + public long getDocCount() { + return docCount; + } + + public void setDocCount(long count) { + this.docCount = count; + } + } } From 05cb03ccf64af338ef394c3af0ce7221fff622a2 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 19 Jul 2025 02:58:17 +0000 Subject: [PATCH 09/38] working ResumableDISI integration into ApproximatePointRangeQuery Signed-off-by: Sawan Srivastava --- .../ApproximatePointRangeQuery.java | 226 +++++++++++++++++- .../search/approximate/ResumableDISI.java | 9 + 2 files changed, 222 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 35e1982007a4f..7ff2777070818 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -106,6 +106,10 @@ public void setSortOrder(SortOrder sortOrder) { this.sortOrder = sortOrder; } + public void setBKDState(ResumableDISI.BKDState state) { + this.state = state; + } + @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); @@ -257,35 +261,87 @@ private void intersectRight(PointValues.PointTree pointTree, PointValues.Interse // custom intersect visitor to walk the left of the tree public void 
intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) throws IOException { + // Check if we've already collected enough documents if (docCount[0] >= size) { + // If we have state, save the current tree as the next node to visit + if (state != null) { + state.setCurrentTree(pointTree); + state.setInProgress(true); + } return; } + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { return; } + // Handle leaf nodes if (pointTree.moveToChild() == false) { if (r == PointValues.Relation.CELL_INSIDE_QUERY) { - pointTree.visitDocIDs(visitor); + // Save state before visiting docs if we're close to the limit + if (state != null && pointTree.size() + docCount[0] >= size) { + // Clone the tree before visiting docs + PointValues.PointTree nextNode = pointTree.clone(); + pointTree.visitDocIDs(visitor); + + // If we've hit the limit, save the next node + if (docCount[0] >= size) { + state.setCurrentTree(nextNode); + state.setInProgress(true); + return; + } + } else { + pointTree.visitDocIDs(visitor); + } } else { // CELL_CROSSES_QUERY - pointTree.visitDocValues(visitor); + // Save state before visiting docs if we're close to the limit + if (state != null && pointTree.size() + docCount[0] >= size) { + // Clone the tree before visiting docs + PointValues.PointTree nextNode = pointTree.clone(); + pointTree.visitDocValues(visitor); + + // If we've hit the limit, save the next node + if (docCount[0] >= size) { + state.setCurrentTree(nextNode); + state.setInProgress(true); + return; + } + } else { + pointTree.visitDocValues(visitor); + } } return; } + // For CELL_INSIDE_QUERY, check if we can skip right child if (r == PointValues.Relation.CELL_INSIDE_QUERY) { long leftSize = pointTree.size(); long needed = size - docCount[0]; if (leftSize >= needed) { - // Process only left child - intersectLeft(visitor, pointTree, docCount); + // Save 
state before processing left child if we're going to hit the limit + if (state != null && leftSize >= needed) { + // Clone the current position + PointValues.PointTree currentPos = pointTree.clone(); + + // Process left child + intersectLeft(visitor, pointTree, docCount); + + // If we've hit the limit, the state is already saved in the recursive call + if (docCount[0] >= size) { + return; + } + } else { + // Process only left child + intersectLeft(visitor, pointTree, docCount); + } pointTree.moveToParent(); return; } } + // We need both children - now clone right PointValues.PointTree rightChild = null; if (pointTree.moveToSibling()) { @@ -293,59 +349,150 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin pointTree.moveToParent(); pointTree.moveToChild(); } - // Process both children: left first, then right if needed + + // Process left child first intersectLeft(visitor, pointTree, docCount); - if (docCount[0] < size && rightChild != null) { + + // If we've hit the limit, return (state is already saved in the recursive call) + if (docCount[0] >= size) { + return; + } + + // Process right child if needed + if (rightChild != null && docCount[0] < size) { intersectLeft(visitor, rightChild, docCount); } + pointTree.moveToParent(); } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) throws IOException { + // Check if we've already collected enough documents if (docCount[0] >= size) { + // If we have state, save the current tree as the next node to visit + if (state != null) { + state.setCurrentTree(pointTree); + state.setInProgress(true); + } return; } + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { return; } + // Handle leaf nodes if (pointTree.moveToChild() == false) { if (r == 
PointValues.Relation.CELL_INSIDE_QUERY) { - pointTree.visitDocIDs(visitor); + // Save state before visiting docs if we're close to the limit + if (state != null && pointTree.size() + docCount[0] >= size) { + // Clone the tree before visiting docs + PointValues.PointTree nextNode = pointTree.clone(); + pointTree.visitDocIDs(visitor); + + // If we've hit the limit, save the next node + if (docCount[0] >= size) { + state.setCurrentTree(nextNode); + state.setInProgress(true); + return; + } + } else { + pointTree.visitDocIDs(visitor); + } } else { // CELL_CROSSES_QUERY - pointTree.visitDocValues(visitor); + // Save state before visiting docs if we're close to the limit + if (state != null && pointTree.size() + docCount[0] >= size) { + // Clone the tree before visiting docs + PointValues.PointTree nextNode = pointTree.clone(); + pointTree.visitDocValues(visitor); + + // If we've hit the limit, save the next node + if (docCount[0] >= size) { + state.setCurrentTree(nextNode); + state.setInProgress(true); + return; + } + } else { + pointTree.visitDocValues(visitor); + } } return; } + // Internal node - get left child reference (we're at left child initially) PointValues.PointTree leftChild = pointTree.clone(); + // Move to right child if it exists boolean hasRightChild = pointTree.moveToSibling(); + // For CELL_INSIDE_QUERY, check if we can skip left child if (r == PointValues.Relation.CELL_INSIDE_QUERY && hasRightChild) { long rightSize = pointTree.size(); long needed = size - docCount[0]; if (rightSize >= needed) { - // Right child has all we need - only process right - intersectRight(visitor, pointTree, docCount); + // Save state before processing right child if we're going to hit the limit + if (state != null && rightSize >= needed) { + // Clone the current position + PointValues.PointTree currentPos = pointTree.clone(); + + // Process right child + intersectRight(visitor, pointTree, docCount); + + // If we've hit the limit, the state is already saved in the recursive 
call + if (docCount[0] >= size) { + return; + } + } else { + // Right child has all we need - only process right + intersectRight(visitor, pointTree, docCount); + } pointTree.moveToParent(); return; } } - // Process both children: right first (for DESC), then left if needed + + // Process right child first (for DESC) if (hasRightChild) { intersectRight(visitor, pointTree, docCount); + + // If we've hit the limit, return (state is already saved in the recursive call) + if (docCount[0] >= size) { + return; + } } + + // Process left child if needed if (docCount[0] < size) { intersectRight(visitor, leftChild, docCount); } + pointTree.moveToParent(); } + private void captureStateBeforeIntersect() { + // Save the current state before intersect + stateBeforeIntersect = state.getCurrentTree() != null ? state.getCurrentTree().clone() : null; + } + + private void updateStateAfterIntersect(long[] docCount) { + // If we've collected enough documents, we need to save state + if (docCount[0] >= size && stateBeforeIntersect != null) { + // We've collected enough documents, save the state from before the intersect + // This is a simplification - ideally we'd save the exact point where we stopped + state.setCurrentTree(stateBeforeIntersect); + } else { + // We've exhausted the tree + state.setExhausted(true); + } + } + + // Temporary variable to hold state during intersect + private PointValues.PointTree stateBeforeIntersect; + @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { LeafReader reader = context.reader(); @@ -359,6 +506,19 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (size > values.size()) { return pointRangeQueryWeight.scorerSupplier(context); } else { + + if (state == null) { + state = new ResumableDISI.BKDState(); + } + + // Reset docCount if we're starting fresh + if (state.getCurrentTree() == null && !state.isExhausted()) { + docCount[0] = 0; + } else { + // Resume from where we 
left off + docCount[0] = state.getDocCount(); + } + if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { @@ -368,7 +528,27 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti @Override public Scorer get(long leadCost) throws IOException { - intersectLeft(values.getPointTree(), visitor, docCount); + // Check if we have a saved tree and we're not exhausted + if (state.getCurrentTree() == null && !state.isExhausted()) { + // First call - start from the root + state.setCurrentTree(values.getPointTree()); + } + + // Only process if we haven't collected enough documents and we're not exhausted + if (!state.isExhausted() && docCount[0] < size) { + // Call intersect with the current tree + // The state will be updated inside intersectLeft + intersectLeft(state.getCurrentTree(), visitor, docCount); + + // Update the state's docCount + state.setDocCount(docCount[0]); + + // If we didn't collect enough documents and we're not in progress, we've exhausted the tree + if (docCount[0] < size && !state.isInProgress()) { + state.setExhausted(true); + } + } + DocIdSetIterator iterator = result.build().iterator(); return new ConstantScoreScorer(score(), scoreMode, iterator); } @@ -396,7 +576,27 @@ public long cost() { @Override public Scorer get(long leadCost) throws IOException { - intersectRight(values.getPointTree(), visitor, docCount); + // Check if we have a saved tree and we're not exhausted + if (state.getCurrentTree() == null && !state.isExhausted()) { + // First call - start from the root + state.setCurrentTree(values.getPointTree()); + } + + // Only process if we haven't collected enough documents and we're not exhausted + if (!state.isExhausted() && docCount[0] < size) { + // Call intersect with the current tree + // The state will be updated inside intersectRight + intersectRight(state.getCurrentTree(), visitor, docCount); + + // Update the state's docCount + state.setDocCount(docCount[0]); + + // If we 
didn't collect enough documents and we're not in progress, we've exhausted the tree + if (docCount[0] < size && !state.isInProgress()) { + state.setExhausted(true); + } + } + DocIdSetIterator iterator = result.build().iterator(); return new ConstantScoreScorer(score(), scoreMode, iterator); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index bf3c2b3cf9327..c6d6a2ee4a89c 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -172,6 +172,7 @@ public static class BKDState { private PointValues.PointTree currentTree; private boolean isExhausted = false; private long docCount = 0; + private boolean inProgress = false; public PointValues.PointTree getCurrentTree() { return currentTree; @@ -200,5 +201,13 @@ public long getDocCount() { public void setDocCount(long count) { this.docCount = count; } + + public boolean isInProgress() { + return inProgress; + } + + public void setInProgress(boolean inProgress) { + this.inProgress = inProgress; + } } } From a139c44a884040bb7340c44da525a0045ca7f225 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 19 Jul 2025 03:59:47 +0000 Subject: [PATCH 10/38] =?UTF-8?q?Working=20multi=20clause=20boolean=20appr?= =?UTF-8?q?oximation=20(finished=20at=20the=20airport=20=F0=9F=98=83)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 55 ++---- .../ApproximateBooleanScorerSupplier.java | 187 ++++++++++++++++-- .../ApproximateConjunctionDISI.java | 123 +++++++++++- .../ApproximateConjunctionScorer.java | 61 +++--- 4 files changed, 334 insertions(+), 92 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java 
b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index dca6118065d5a..631a93a9f2131 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -11,20 +11,18 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.ConstantScoreQuery; -import org.apache.lucene.search.ConstantScoreScorer; import org.apache.lucene.search.ConstantScoreWeight; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.opensearch.search.internal.SearchContext; import java.io.IOException; +import java.util.ArrayList; import java.util.List; /** @@ -121,14 +119,16 @@ public ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreM * This is a basic implementation that behaves like a regular filter boolean query for now. 
*/ private class ApproximateBooleanWeight extends ConstantScoreWeight { - private final Weight booleanWeight; private final ScoreMode scoreMode; + private final IndexSearcher searcher; + private final float boost; public ApproximateBooleanWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { super(ApproximateBooleanQuery.this, boost); // Create a weight for the underlying boolean query - this.booleanWeight = boolQuery.createWeight(searcher, scoreMode, boost); this.scoreMode = scoreMode; + this.searcher = searcher; + this.boost = boost; } @Override @@ -136,49 +136,16 @@ public boolean isCacheable(LeafReaderContext ctx) { return false; } - // public BulkScorer bulkScorer(LeafReaderContext context) throws IOException { - // ScorerSupplier scorerSupplier = scorerSupplier(context); - // if (scorerSupplier == null) { - // // No docs match - // return null; - // } - // - // scorerSupplier.setTopLevelScoringClause(); - // return scorerSupplier.bulkScorer(); - // } - @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { - // Get the scorer supplier from the underlying boolean weight - final ScorerSupplier booleanScorer = booleanWeight.scorerSupplier(context); - if (booleanScorer == null) { - return null; + // For multi-clause boolean queries, create a custom scorer supplier + List clauseWeights = new ArrayList<>(clauses.size()); + for (BooleanClause clause : clauses) { + Weight weight = clause.query().createWeight(searcher, scoreMode, boost); + clauseWeights.add(weight); } - // return new ApproximateBooleanScorerSupplier(); - // Return a wrapper scorer supplier that delegates to the boolean scorer - return new ScorerSupplier() { - @Override - public Scorer get(long leadCost) throws IOException { - Scorer scorer = booleanScorer.get(leadCost); - if (scorer == null) { - return null; - } - return new ConstantScoreScorer(score(), scoreMode, scorer.iterator()); - } - - @Override - public long cost() { - 
return booleanScorer.cost(); - } - - @Override - public BulkScorer bulkScorer() throws IOException { - // For now, just delegate to the standard bulk scorer - // In the future, this is where we would implement our custom bulk scorer - return booleanScorer.bulkScorer(); - } - }; + return new ApproximateBooleanScorerSupplier(clauseWeights, scoreMode, boost, size, context); } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index cb6e3ac8ae59d..48bdf55934ac2 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -8,46 +8,203 @@ package org.opensearch.search.approximate; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.Bits; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +/** + * A ScorerSupplier implementation for ApproximateBooleanQuery that creates resumable DocIdSetIterators + * for each clause and coordinates their usage in the boolean query context. + */ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { + private final List clauseScorerSuppliers; + private final ScoreMode scoreMode; + private final float boost; + private final int threshold; + private final LeafReaderContext context; + private long cost = -1; + + /** + * Creates a new ApproximateBooleanScorerSupplier. 
+ * + * @param clauseWeights The weights for each clause in the boolean query + * @param scoreMode The score mode + * @param boost The boost factor + * @param threshold The threshold for early termination + * @param context The leaf reader context + * @throws IOException If there's an error creating scorer suppliers + */ + public ApproximateBooleanScorerSupplier( + List clauseWeights, + ScoreMode scoreMode, + float boost, + int threshold, + LeafReaderContext context + ) throws IOException { + this.clauseScorerSuppliers = new ArrayList<>(clauseWeights.size()); + this.scoreMode = scoreMode; + this.boost = boost; + this.threshold = threshold; + this.context = context; + + // Create scorer suppliers for each clause + for (Weight clauseWeight : clauseWeights) { + ScorerSupplier supplier = clauseWeight.scorerSupplier(context); + if (supplier != null) { + clauseScorerSuppliers.add(supplier); + } + } + } /** * Get the {@link Scorer}. This may not return {@code null} and must be called at most once. * - * @param leadCost Cost of the scorer that will be used in order to lead iteration. This can be - * interpreted as an upper bound of the number of times that {@link DocIdSetIterator#nextDoc}, - * {@link DocIdSetIterator#advance} and TwoPhaseIterator#matches will be called. Under - * doubt, pass {@link Long#MAX_VALUE}, which will produce a {@link Scorer} that has good - * iteration capabilities. + * @param leadCost Cost of the scorer that will be used in order to lead iteration. 
*/ @Override public Scorer get(long leadCost) throws IOException { - return null; + if (clauseScorerSuppliers.isEmpty()) { + return null; + } + + // Create ResumableDISIs for each clause + List clauseIterators = new ArrayList<>(clauseScorerSuppliers.size()); + for (ScorerSupplier supplier : clauseScorerSuppliers) { + ResumableDISI disi = new ResumableDISI(supplier); + clauseIterators.add(disi); + } + + // Create an ApproximateConjunctionScorer with the clause iterators + return new ApproximateConjunctionScorer(boost, scoreMode, clauseIterators); } /** - * Optional method: Get a scorer that is optimized for bulk-scoring. The default implementation - * iterates matches from the {@link Scorer} but some queries can have more efficient approaches - * for matching all hits. + * Get a scorer that is optimized for bulk-scoring. */ + @Override public BulkScorer bulkScorer() throws IOException { - return null; + if (clauseScorerSuppliers.isEmpty()) { + return null; + } + + // Create ResumableDISIs for each clause + List clauseIterators = new ArrayList<>(clauseScorerSuppliers.size()); + for (ScorerSupplier supplier : clauseScorerSuppliers) { + ResumableDISI disi = new ResumableDISI(supplier); + clauseIterators.add(disi); + } + + // Create an ApproximateBooleanBulkScorer with the clause iterators + return new ApproximateBooleanBulkScorer(clauseIterators, threshold); } /** - * Get an estimate of the {@link Scorer} that would be returned by {@link #get}. This may be a - * costly operation, so it should only be called if necessary. - * - * @see DocIdSetIterator#cost + * Get an estimate of the {@link Scorer} that would be returned by {@link #get}. 
*/ @Override public long cost() { - return 0; + if (cost == -1) { + // Estimate cost as the minimum cost of all clauses (conjunction) + if (!clauseScorerSuppliers.isEmpty()) { + cost = Long.MAX_VALUE; + for (ScorerSupplier supplier : clauseScorerSuppliers) { + cost = Math.min(cost, supplier.cost()); + } + } else { + cost = 0; + } + } + return cost; + } + + /** + * A BulkScorer implementation that coordinates multiple ResumableDISIs to implement + * the circular scoring process described in the blog. + */ + private static class ApproximateBooleanBulkScorer extends BulkScorer { + private final List clauseIterators; + private final int threshold; + + public ApproximateBooleanBulkScorer(List clauseIterators, int threshold) { + this.clauseIterators = clauseIterators; + this.threshold = threshold; + } + + @Override + public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { + // Create an ApproximateConjunctionDISI to coordinate the clause iterators + ApproximateConjunctionDISI conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); + + // Create a scorer for the collector + ApproximateConjunctionScorer scorer = new ApproximateConjunctionScorer(1.0f, ScoreMode.COMPLETE, clauseIterators); + + // Set the scorer on the collector + collector.setScorer(scorer); + + // Track how many documents we've collected + int collected = 0; + int docID; + + // Collect documents until we reach the threshold or exhaust the iterator + while (collected < threshold && (docID = conjunctionDISI.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (docID >= max) { + // We've reached the end of the range + return docID; + } + + if (docID >= min && (acceptDocs == null || acceptDocs.get(docID))) { + // Collect the document + collector.collect(docID); + collected++; + } + } + + // If we haven't collected enough documents and the iterator isn't exhausted, + // we need to rescore the clauses and continue + if (collected < threshold && 
!conjunctionDISI.isExhausted()) { + // Reset each clause iterator for the next batch + for (ResumableDISI disi : clauseIterators) { + disi.resetForNextBatch(); + } + + // Create a new conjunction DISI with the reset iterators + conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); + + // Continue collecting documents + while (collected < threshold && (docID = conjunctionDISI.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + if (docID >= max) { + // We've reached the end of the range + return docID; + } + + if (docID >= min && (acceptDocs == null || acceptDocs.get(docID))) { + // Collect the document + collector.collect(docID); + collected++; + } + } + } + + // We've either collected enough documents or exhausted the iterator + return DocIdSetIterator.NO_MORE_DOCS; + } + + /** + * Same as {@link DocIdSetIterator#cost()} for bulk scorers. + */ + @Override + public long cost() { + return 0; + } } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java index 6b5fef5876af2..263d26febf80b 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -9,15 +9,126 @@ package org.opensearch.search.approximate; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.FilterDocIdSetIterator; -public class ApproximateConjunctionDISI extends FilterDocIdSetIterator { +import java.io.IOException; +import java.util.List; + +/** + * A custom conjunction coordinator that understands resumable iterators. + * This class coordinates multiple ResumableDISIs to find documents that match all clauses. 
+ */ +public class ApproximateConjunctionDISI extends DocIdSetIterator { + private final List iterators; + private final ResumableDISI lead; + private final ResumableDISI[] others; + private int doc = -1; + private boolean exhausted = false; + + /** + * Creates a new ApproximateConjunctionDISI. + * + * @param iterators The iterators to coordinate + */ + public ApproximateConjunctionDISI(List iterators) { + if (iterators.isEmpty()) { + throw new IllegalArgumentException("No iterators provided"); + } + + this.iterators = iterators; + + // Sort iterators by cost (ascending) + iterators.sort((a, b) -> Long.compare(a.cost(), b.cost())); + + // Use the cheapest iterator as the lead + this.lead = iterators.get(0); + + // Store the other iterators + this.others = new ResumableDISI[iterators.size() - 1]; + for (int i = 1; i < iterators.size(); i++) { + others[i - 1] = iterators.get(i); + } + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + if (exhausted) { + return doc = NO_MORE_DOCS; + } + + // Advance the lead iterator + doc = lead.nextDoc(); + + if (doc == NO_MORE_DOCS) { + exhausted = true; + return doc; + } + + // Try to align all other iterators + return doNext(doc); + } + + @Override + public int advance(int target) throws IOException { + if (exhausted) { + return doc = NO_MORE_DOCS; + } + + // Advance the lead iterator + doc = lead.advance(target); + + if (doc == NO_MORE_DOCS) { + exhausted = true; + return doc; + } + + // Try to align all other iterators + return doNext(doc); + } + + /** + * Coordinates multiple iterators to find documents that match all clauses. + * This is similar to ConjunctionDISI.doNext() but adapted for resumable iterators. 
+ */ + private int doNext(int doc) throws IOException { + advanceHead: for (;;) { + // Try to align all other iterators with the lead + for (ResumableDISI other : others) { + if (other.docID() < doc) { + final int next = other.advance(doc); + if (next > doc) { + // This iterator is ahead, advance the lead to catch up + doc = lead.advance(next); + if (doc == NO_MORE_DOCS) { + exhausted = true; + return this.doc = NO_MORE_DOCS; + } + continue advanceHead; + } + } + } + + // All iterators are aligned at the current doc + return this.doc = doc; + } + } + + @Override + public long cost() { + // Return the cost of the cheapest iterator + return lead.cost(); + } + /** - * Sole constructor. + * Returns whether this iterator has been exhausted. * - * @param in + * @return true if there are no more documents to score */ - public ApproximateConjunctionDISI(DocIdSetIterator in) { - super(in); + public boolean isExhausted() { + return exhausted; } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java index 4bcad6a58c1ea..064df294670e4 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java @@ -9,50 +9,57 @@ package org.opensearch.search.approximate; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TwoPhaseIterator; import java.io.IOException; +import java.util.List; +/** + * A custom Scorer that manages an ApproximateConjunctionDISI. + * This class creates and manages an ApproximateConjunctionDISI to score documents + * that match all clauses in a boolean query. + */ public class ApproximateConjunctionScorer extends Scorer { - /** - * Returns the doc ID that is currently being scored. 
- */ - @Override - public int docID() { - return 0; - } + private final ApproximateConjunctionDISI approximateConjunctionDISI; + private final float score; /** - * Return a {@link DocIdSetIterator} over matching documents. + * Creates a new ApproximateConjunctionScorer. * - *

The returned iterator will either be positioned on {@code -1} if no documents have been - * scored yet, {@link DocIdSetIterator#NO_MORE_DOCS} if all documents have been scored already, or - * the last document id that has been scored otherwise. - * - *

The returned iterator is a view: calling this method several times will return iterators - * that have the same state. + * @param boost The boost factor + * @param scoreMode The score mode + * @param iterators The iterators to coordinate */ + public ApproximateConjunctionScorer(float boost, ScoreMode scoreMode, List iterators) { + // Scorer doesn't have a constructor that takes arguments + this.approximateConjunctionDISI = new ApproximateConjunctionDISI(iterators); + this.score = boost; + } + @Override public DocIdSetIterator iterator() { - return null; + return approximateConjunctionDISI; + } + + @Override + public float score() throws IOException { + return score; } - /** - * Return the maximum score that documents between the last {@code target} that this iterator was - * {@link #advanceShallow(int) shallow-advanced} to included and {@code upTo} included. - * - * @param upTo - */ @Override public float getMaxScore(int upTo) throws IOException { - return 0; + return score; } - /** - * Returns the score of the current document matching the query. 
- */ @Override - public float score() throws IOException { - return 0; + public int docID() { + return approximateConjunctionDISI.docID(); + } + + @Override + public TwoPhaseIterator twoPhaseIterator() { + return null; // No two-phase iteration needed for conjunction } } From 6f277851a90c110e57be0ead595a635011c3ff23 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 26 Jul 2025 17:36:20 +0000 Subject: [PATCH 11/38] changed implementation to only create ResumableDISIs for approximated clauses Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 81 +++++++++++++------ .../ApproximateConjunctionDISI.java | 20 ++--- .../ApproximateConjunctionScorer.java | 8 +- .../search/approximate/ResumableDISI.java | 27 +++---- 4 files changed, 83 insertions(+), 53 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 48bdf55934ac2..670869730794f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; @@ -27,7 +28,7 @@ * for each clause and coordinates their usage in the boolean query context. 
*/ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { - private final List clauseScorerSuppliers; + private final List clauseWeights; private final ScoreMode scoreMode; private final float boost; private final int threshold; @@ -51,17 +52,17 @@ public ApproximateBooleanScorerSupplier( int threshold, LeafReaderContext context ) throws IOException { - this.clauseScorerSuppliers = new ArrayList<>(clauseWeights.size()); + this.clauseWeights = new ArrayList<>(); this.scoreMode = scoreMode; this.boost = boost; this.threshold = threshold; this.context = context; - // Create scorer suppliers for each clause + // Store weights that have valid scorer suppliers for (Weight clauseWeight : clauseWeights) { ScorerSupplier supplier = clauseWeight.scorerSupplier(context); if (supplier != null) { - clauseScorerSuppliers.add(supplier); + this.clauseWeights.add(clauseWeight); } } } @@ -73,15 +74,25 @@ public ApproximateBooleanScorerSupplier( */ @Override public Scorer get(long leadCost) throws IOException { - if (clauseScorerSuppliers.isEmpty()) { + if (clauseWeights.isEmpty()) { return null; } - // Create ResumableDISIs for each clause - List clauseIterators = new ArrayList<>(clauseScorerSuppliers.size()); - for (ScorerSupplier supplier : clauseScorerSuppliers) { - ResumableDISI disi = new ResumableDISI(supplier); - clauseIterators.add(disi); + // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries + List clauseIterators = new ArrayList<>(clauseWeights.size()); + for (Weight weight : clauseWeights) { + Query query = weight.getQuery(); + ScorerSupplier supplier = weight.scorerSupplier(context); + + if (query instanceof ApproximateQuery) { + // Use ResumableDISI for approximatable queries + ResumableDISI disi = new ResumableDISI(supplier); + clauseIterators.add(disi); + } else { + // Use regular DocIdSetIterator for non-approximatable queries + Scorer scorer = supplier.get(leadCost); + clauseIterators.add(scorer.iterator()); 
+ } } // Create an ApproximateConjunctionScorer with the clause iterators @@ -93,15 +104,25 @@ public Scorer get(long leadCost) throws IOException { */ @Override public BulkScorer bulkScorer() throws IOException { - if (clauseScorerSuppliers.isEmpty()) { + if (clauseWeights.isEmpty()) { return null; } - // Create ResumableDISIs for each clause - List clauseIterators = new ArrayList<>(clauseScorerSuppliers.size()); - for (ScorerSupplier supplier : clauseScorerSuppliers) { - ResumableDISI disi = new ResumableDISI(supplier); - clauseIterators.add(disi); + // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries + List clauseIterators = new ArrayList<>(clauseWeights.size()); + for (Weight weight : clauseWeights) { + Query query = weight.getQuery(); + ScorerSupplier supplier = weight.scorerSupplier(context); + + if (query instanceof ApproximateQuery) { + // Use ResumableDISI for approximatable queries + ResumableDISI disi = new ResumableDISI(supplier); + clauseIterators.add(disi); + } else { + // Use regular DocIdSetIterator for non-approximatable queries + Scorer scorer = supplier.get(supplier.cost()); + clauseIterators.add(scorer.iterator()); + } } // Create an ApproximateBooleanBulkScorer with the clause iterators @@ -115,10 +136,18 @@ public BulkScorer bulkScorer() throws IOException { public long cost() { if (cost == -1) { // Estimate cost as the minimum cost of all clauses (conjunction) - if (!clauseScorerSuppliers.isEmpty()) { + if (!clauseWeights.isEmpty()) { cost = Long.MAX_VALUE; - for (ScorerSupplier supplier : clauseScorerSuppliers) { - cost = Math.min(cost, supplier.cost()); + for (Weight weight : clauseWeights) { + try { + ScorerSupplier supplier = weight.scorerSupplier(context); + if (supplier != null) { + cost = Math.min(cost, supplier.cost()); + } + } catch (IOException e) { + // If we can't get the cost, use a default + cost = Math.min(cost, 1000); + } } } else { cost = 0; @@ -128,14 +157,14 @@ public long cost() 
{ } /** - * A BulkScorer implementation that coordinates multiple ResumableDISIs to implement + * A BulkScorer implementation that coordinates multiple DocIdSetIterators (including ResumableDISIs) to implement * the circular scoring process described in the blog. */ private static class ApproximateBooleanBulkScorer extends BulkScorer { - private final List clauseIterators; + private final List clauseIterators; private final int threshold; - public ApproximateBooleanBulkScorer(List clauseIterators, int threshold) { + public ApproximateBooleanBulkScorer(List clauseIterators, int threshold) { this.clauseIterators = clauseIterators; this.threshold = threshold; } @@ -172,9 +201,11 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // If we haven't collected enough documents and the iterator isn't exhausted, // we need to rescore the clauses and continue if (collected < threshold && !conjunctionDISI.isExhausted()) { - // Reset each clause iterator for the next batch - for (ResumableDISI disi : clauseIterators) { - disi.resetForNextBatch(); + // Reset only the ResumableDISI iterators for the next batch + for (DocIdSetIterator disi : clauseIterators) { + if (disi instanceof ResumableDISI) { + ((ResumableDISI) disi).resetForNextBatch(); + } } // Create a new conjunction DISI with the reset iterators diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java index 263d26febf80b..1b03a7953694c 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -14,22 +14,22 @@ import java.util.List; /** - * A custom conjunction coordinator that understands resumable iterators. - * This class coordinates multiple ResumableDISIs to find documents that match all clauses. 
+ * A custom conjunction coordinator that understands both resumable and regular iterators. + * This class coordinates multiple DocIdSetIterators (which may include ResumableDISIs) to find documents that match all clauses. */ public class ApproximateConjunctionDISI extends DocIdSetIterator { - private final List iterators; - private final ResumableDISI lead; - private final ResumableDISI[] others; + private final List iterators; + private final DocIdSetIterator lead; + private final DocIdSetIterator[] others; private int doc = -1; private boolean exhausted = false; /** * Creates a new ApproximateConjunctionDISI. * - * @param iterators The iterators to coordinate + * @param iterators The iterators to coordinate (mix of ResumableDISI and regular DocIdSetIterator) */ - public ApproximateConjunctionDISI(List iterators) { + public ApproximateConjunctionDISI(List iterators) { if (iterators.isEmpty()) { throw new IllegalArgumentException("No iterators provided"); } @@ -43,7 +43,7 @@ public ApproximateConjunctionDISI(List iterators) { this.lead = iterators.get(0); // Store the other iterators - this.others = new ResumableDISI[iterators.size() - 1]; + this.others = new DocIdSetIterator[iterators.size() - 1]; for (int i = 1; i < iterators.size(); i++) { others[i - 1] = iterators.get(i); } @@ -92,12 +92,12 @@ public int advance(int target) throws IOException { /** * Coordinates multiple iterators to find documents that match all clauses. - * This is similar to ConjunctionDISI.doNext() but adapted for resumable iterators. + * This is similar to ConjunctionDISI.doNext() but adapted for mixed iterator types. 
*/ private int doNext(int doc) throws IOException { advanceHead: for (;;) { // Try to align all other iterators with the lead - for (ResumableDISI other : others) { + for (DocIdSetIterator other : others) { if (other.docID() < doc) { final int next = other.advance(doc); if (next > doc) { diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java index 064df294670e4..55141c4648e5c 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java @@ -30,9 +30,9 @@ public class ApproximateConjunctionScorer extends Scorer { * * @param boost The boost factor * @param scoreMode The score mode - * @param iterators The iterators to coordinate + * @param iterators The iterators to coordinate (mix of ResumableDISI and regular DocIdSetIterator) */ - public ApproximateConjunctionScorer(float boost, ScoreMode scoreMode, List iterators) { + public ApproximateConjunctionScorer(float boost, ScoreMode scoreMode, List iterators) { // Scorer doesn't have a constructor that takes arguments this.approximateConjunctionDISI = new ApproximateConjunctionDISI(iterators); this.score = boost; @@ -45,12 +45,12 @@ public DocIdSetIterator iterator() { @Override public float score() throws IOException { - return score; + return 0.0f; } @Override public float getMaxScore(int upTo) throws IOException { - return score; + return 0.0f; } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index c6d6a2ee4a89c..f908c6c13d50f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -31,9 +31,10 @@ public class 
ResumableDISI extends DocIdSetIterator { private final int batchSize; private boolean exhausted = false; - // State tracking + // State tracking - track batches, not individual document movements private int lastDocID = -1; - private int docsScored = 0; + private int batchCount = 0; // How many batches we've created + private boolean needsNewBatch = true; // Whether we need to create a new batch /** * Creates a new ResumableDISI with the default batch size of 10,000 documents. @@ -57,8 +58,7 @@ public ResumableDISI(ScorerSupplier scorerSupplier, int batchSize) { /** * Initializes or resets the internal DocIdSetIterator. - * If this is the first call or we've reached the batch limit, a new DISI is created. - * Otherwise, the existing DISI is reused. + * Creates a new DISI only when we need a new batch. * * @return The current DocIdSetIterator * @throws IOException If there's an error getting the scorer @@ -68,7 +68,7 @@ private DocIdSetIterator getOrCreateDisi() throws IOException { return currentDisi; // Already exhausted, no need to create a new one } - if (currentDisi == null || docsScored >= batchSize) { + if (currentDisi == null || needsNewBatch) { // Get a new scorer and its iterator Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); currentDisi = scorer.iterator(); @@ -78,8 +78,9 @@ private DocIdSetIterator getOrCreateDisi() throws IOException { currentDisi.advance(lastDocID + 1); } - // Reset the docs scored counter for this batch - docsScored = 0; + // Mark that we've created a new batch + batchCount++; + needsNewBatch = false; } return currentDisi; @@ -100,7 +101,6 @@ public int nextDoc() throws IOException { if (doc != NO_MORE_DOCS) { lastDocID = doc; - docsScored++; } else { exhausted = true; } @@ -115,7 +115,6 @@ public int advance(int target) throws IOException { if (doc != NO_MORE_DOCS) { lastDocID = doc; - docsScored++; } else { exhausted = true; } @@ -134,17 +133,17 @@ public long cost() { */ public void resetForNextBatch() { if (!exhausted) 
{ - currentDisi = null; // Force creation of a new DISI on next call + needsNewBatch = true; // Mark that we need a new batch on next access } } /** - * Returns the number of documents scored in the current batch. + * Returns the number of batches created so far. * - * @return The number of documents scored + * @return The number of batches created */ - public int getDocsScored() { - return docsScored; + public int getBatchCount() { + return batchCount; } /** From 35efcde2e6ca7e8d0901ae9d7aa4ca35f2327fae Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 26 Jul 2025 17:54:23 +0000 Subject: [PATCH 12/38] implemented truly resumable scoring Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 62 +++++++++---------- .../ApproximateConjunctionDISI.java | 46 ++++++++++++++ .../search/approximate/ResumableDISI.java | 12 +++- 3 files changed, 85 insertions(+), 35 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 670869730794f..88d934864af70 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -171,7 +171,7 @@ public ApproximateBooleanBulkScorer(List clauseIterators, int @Override public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { - // Create an ApproximateConjunctionDISI to coordinate the clause iterators + // Create an ApproximateConjunctionDISI once and reuse it (preserve conjunction state) ApproximateConjunctionDISI conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); // Create a scorer for the collector @@ -184,8 +184,34 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr int collected = 0; int docID; - // Collect documents 
until we reach the threshold or exhaust the iterator - while (collected < threshold && (docID = conjunctionDISI.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + // Continue collecting until we reach the threshold + while (collected < threshold) { + // Get the next document from the conjunction + docID = conjunctionDISI.nextDoc(); + + if (docID == DocIdSetIterator.NO_MORE_DOCS) { + // No more documents in current state - try to expand ResumableDISIs + boolean anyExpanded = false; + for (DocIdSetIterator disi : clauseIterators) { + if (disi instanceof ResumableDISI) { + ResumableDISI resumableDISI = (ResumableDISI) disi; + if (!resumableDISI.isExhausted()) { + resumableDISI.resetForNextBatch(); // This expands the document set + anyExpanded = true; + } + } + } + + // If no ResumableDISIs were expanded, we're truly done + if (!anyExpanded) { + break; + } + + // Reset the conjunction so it can continue with expanded iterators + conjunctionDISI.resetAfterExpansion(); + continue; + } + if (docID >= max) { // We've reached the end of the range return docID; @@ -198,35 +224,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr } } - // If we haven't collected enough documents and the iterator isn't exhausted, - // we need to rescore the clauses and continue - if (collected < threshold && !conjunctionDISI.isExhausted()) { - // Reset only the ResumableDISI iterators for the next batch - for (DocIdSetIterator disi : clauseIterators) { - if (disi instanceof ResumableDISI) { - ((ResumableDISI) disi).resetForNextBatch(); - } - } - - // Create a new conjunction DISI with the reset iterators - conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); - - // Continue collecting documents - while (collected < threshold && (docID = conjunctionDISI.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { - if (docID >= max) { - // We've reached the end of the range - return docID; - } - - if (docID >= min && (acceptDocs == null || acceptDocs.get(docID))) 
{ - // Collect the document - collector.collect(docID); - collected++; - } - } - } - - // We've either collected enough documents or exhausted the iterator + // We've either collected enough documents or exhausted all possibilities return DocIdSetIterator.NO_MORE_DOCS; } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java index 1b03a7953694c..628f1ea3b2984 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -22,6 +22,7 @@ public class ApproximateConjunctionDISI extends DocIdSetIterator { private final DocIdSetIterator lead; private final DocIdSetIterator[] others; private int doc = -1; + private int lastValidDoc = -1; // Track the last valid document before NO_MORE_DOCS private boolean exhausted = false; /** @@ -64,6 +65,11 @@ public int nextDoc() throws IOException { doc = lead.nextDoc(); if (doc == NO_MORE_DOCS) { + // Before marking as exhausted, check if any ResumableDISI can be expanded + if (canExpandAnyResumableDISI()) { + // Don't mark as exhausted yet - caller can expand and try again + return doc; + } exhausted = true; return doc; } @@ -82,6 +88,11 @@ public int advance(int target) throws IOException { doc = lead.advance(target); if (doc == NO_MORE_DOCS) { + // Before marking as exhausted, check if any ResumableDISI can be expanded + if (canExpandAnyResumableDISI()) { + // Don't mark as exhausted yet - caller can expand and try again + return doc; + } exhausted = true; return doc; } @@ -90,6 +101,35 @@ public int advance(int target) throws IOException { return doNext(doc); } + /** + * Check if any ResumableDISI in the iterators can be expanded (not exhausted) + */ + private boolean canExpandAnyResumableDISI() { + for (DocIdSetIterator iterator : iterators) { + if (iterator 
instanceof ResumableDISI) { + ResumableDISI resumableDISI = (ResumableDISI) iterator; + if (!resumableDISI.isExhausted()) { + return true; + } + } + } + return false; + } + + /** + * Reset the exhausted state so the conjunction can continue after ResumableDISIs are expanded. + * Preserves the current document position to avoid reprocessing documents. + */ + public void resetAfterExpansion() { + // Only reset if we're not truly exhausted (i.e., some ResumableDISI was expanded) + if (canExpandAnyResumableDISI()) { + exhausted = false; + // Set doc to the last valid document we processed + // The next call to nextDoc() will advance from this position + doc = lastValidDoc; + } + } + /** * Coordinates multiple iterators to find documents that match all clauses. * This is similar to ConjunctionDISI.doNext() but adapted for mixed iterator types. @@ -104,6 +144,11 @@ private int doNext(int doc) throws IOException { // This iterator is ahead, advance the lead to catch up doc = lead.advance(next); if (doc == NO_MORE_DOCS) { + // Before marking as exhausted, check if any ResumableDISI can be expanded + if (canExpandAnyResumableDISI()) { + // Don't mark as exhausted yet - caller can expand and try again + return this.doc = NO_MORE_DOCS; + } exhausted = true; return this.doc = NO_MORE_DOCS; } @@ -113,6 +158,7 @@ private int doNext(int doc) throws IOException { } // All iterators are aligned at the current doc + lastValidDoc = doc; // Remember this valid document return this.doc = doc; } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index f908c6c13d50f..15c9c03a58aae 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -58,7 +58,8 @@ public ResumableDISI(ScorerSupplier scorerSupplier, int batchSize) { /** * Initializes or resets the internal 
DocIdSetIterator. - * Creates a new DISI only when we need a new batch. + * For approximatable queries, this leverages their existing resumable mechanism. + * For non-approximatable queries, this creates new scorers as needed. * * @return The current DocIdSetIterator * @throws IOException If there's an error getting the scorer @@ -70,12 +71,17 @@ private DocIdSetIterator getOrCreateDisi() throws IOException { if (currentDisi == null || needsNewBatch) { // Get a new scorer and its iterator + // For approximatable queries, the scorer supplier will handle resumable state internally Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); currentDisi = scorer.iterator(); - // If we have a last document ID, advance to the next one + // For non-approximatable queries, we need to advance past the last document + // For approximatable queries, they handle this internally via their BKD state if (lastDocID >= 0) { - currentDisi.advance(lastDocID + 1); + // Check if we need to advance (for non-approximatable queries) + if (currentDisi.docID() <= lastDocID) { + currentDisi.advance(lastDocID + 1); + } } // Mark that we've created a new batch From d91f825a3869c20af141e71463d90dc1dab62eae Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 26 Jul 2025 18:01:08 +0000 Subject: [PATCH 13/38] enabled resumableDISI expansions for multiple clauses Signed-off-by: Sawan Srivastava --- .../ApproximateConjunctionDISI.java | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java index 628f1ea3b2984..ee384647327b2 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -118,14 +118,31 @@ private boolean canExpandAnyResumableDISI() { /** * 
Reset the exhausted state so the conjunction can continue after ResumableDISIs are expanded. - * Preserves the current document position to avoid reprocessing documents. + * Ensures all iterators are positioned correctly to continue from where we left off. */ - public void resetAfterExpansion() { + public void resetAfterExpansion() throws IOException { // Only reset if we're not truly exhausted (i.e., some ResumableDISI was expanded) if (canExpandAnyResumableDISI()) { exhausted = false; - // Set doc to the last valid document we processed - // The next call to nextDoc() will advance from this position + + // Position all iterators to continue from after the last valid document + // This ensures we don't reprocess documents we've already seen + int targetDoc = lastValidDoc + 1; + + // Advance the lead iterator to the target position + if (lead.docID() < targetDoc) { + lead.advance(targetDoc); + } + + // Advance all other iterators to the target position + for (DocIdSetIterator other : others) { + if (other.docID() < targetDoc) { + other.advance(targetDoc); + } + } + + // Set doc to lastValidDoc so the next nextDoc() call will advance from there + // This prevents reprocessing documents we've already seen doc = lastValidDoc; } } From b27ebc4c2f25ba35b88a2fbef93c4fee09b7bd37 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Mon, 28 Jul 2025 18:23:24 +0000 Subject: [PATCH 14/38] potential fix for missing docs Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanScorerSupplier.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 88d934864af70..6362af70213ea 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ 
b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -175,7 +175,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr ApproximateConjunctionDISI conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); // Create a scorer for the collector - ApproximateConjunctionScorer scorer = new ApproximateConjunctionScorer(1.0f, ScoreMode.COMPLETE, clauseIterators); + ApproximateConjunctionScorer scorer = new ApproximateConjunctionScorer(0.0f, ScoreMode.COMPLETE, clauseIterators); // Set the scorer on the collector collector.setScorer(scorer); @@ -209,7 +209,14 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // Reset the conjunction so it can continue with expanded iterators conjunctionDISI.resetAfterExpansion(); - continue; + + // After expansion, check if we're already positioned on a valid document + docID = conjunctionDISI.docID(); + if (docID == DocIdSetIterator.NO_MORE_DOCS) { + // Still no document, try nextDoc() in the next iteration + continue; + } + // Fall through to process the current document } if (docID >= max) { From 972ce140f70826460a52d8fce4a8575460250dd1 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Mon, 28 Jul 2025 19:59:11 +0000 Subject: [PATCH 15/38] more like Lucene's ConjunctionDISI Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 9 +- .../ApproximateConjunctionDISI.java | 210 +++++++----------- .../ApproximateConjunctionScorer.java | 13 ++ 3 files changed, 96 insertions(+), 136 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 6362af70213ea..119e86f136cae 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ 
b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -174,8 +174,8 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // Create an ApproximateConjunctionDISI once and reuse it (preserve conjunction state) ApproximateConjunctionDISI conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); - // Create a scorer for the collector - ApproximateConjunctionScorer scorer = new ApproximateConjunctionScorer(0.0f, ScoreMode.COMPLETE, clauseIterators); + // Create a scorer for the collector that reuses the same conjunction + ApproximateConjunctionScorer scorer = new ApproximateConjunctionScorer(0.0f, ScoreMode.COMPLETE, conjunctionDISI); // Set the scorer on the collector collector.setScorer(scorer); @@ -208,8 +208,8 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr } // Reset the conjunction so it can continue with expanded iterators - conjunctionDISI.resetAfterExpansion(); - +// conjunctionDISI.resetAfterExpansion(); + // After expansion, check if we're already positioned on a valid document docID = conjunctionDISI.docID(); if (docID == DocIdSetIterator.NO_MORE_DOCS) { @@ -231,6 +231,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr } } + System.out.println(collected); // We've either collected enough documents or exhausted all possibilities return DocIdSetIterator.NO_MORE_DOCS; } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java index ee384647327b2..8bc95f4d8ed83 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -14,184 +14,130 @@ import java.util.List; /** - * A custom conjunction coordinator that understands both resumable and regular 
iterators. - * This class coordinates multiple DocIdSetIterators (which may include ResumableDISIs) to find documents that match all clauses. + * A conjunction of DocIdSetIterators with support for ResumableDISI expansion. + * Closely mirrors Lucene's ConjunctionDISI architecture with lead1, lead2, and others. */ public class ApproximateConjunctionDISI extends DocIdSetIterator { - private final List iterators; - private final DocIdSetIterator lead; - private final DocIdSetIterator[] others; - private int doc = -1; - private int lastValidDoc = -1; // Track the last valid document before NO_MORE_DOCS - private boolean exhausted = false; - /** - * Creates a new ApproximateConjunctionDISI. - * - * @param iterators The iterators to coordinate (mix of ResumableDISI and regular DocIdSetIterator) - */ - public ApproximateConjunctionDISI(List iterators) { - if (iterators.isEmpty()) { - throw new IllegalArgumentException("No iterators provided"); - } + final DocIdSetIterator lead1, lead2; + final DocIdSetIterator[] others; - this.iterators = iterators; + private final List allIterators; - // Sort iterators by cost (ascending) - iterators.sort((a, b) -> Long.compare(a.cost(), b.cost())); + public ApproximateConjunctionDISI(List iterators) { + if (iterators.size() < 2) { + throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators"); + } - // Use the cheapest iterator as the lead - this.lead = iterators.get(0); + this.allIterators = iterators; - // Store the other iterators - this.others = new DocIdSetIterator[iterators.size() - 1]; - for (int i = 1; i < iterators.size(); i++) { - others[i - 1] = iterators.get(i); - } + // Follow Lucene's exact structure + this.lead1 = iterators.get(0); + this.lead2 = iterators.get(1); + this.others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]); } @Override public int docID() { - return doc; + return lead1.docID(); } @Override public int nextDoc() throws IOException { - if (exhausted) 
{ - return doc = NO_MORE_DOCS; - } - - // Advance the lead iterator - doc = lead.nextDoc(); - - if (doc == NO_MORE_DOCS) { - // Before marking as exhausted, check if any ResumableDISI can be expanded - if (canExpandAnyResumableDISI()) { - // Don't mark as exhausted yet - caller can expand and try again - return doc; - } - exhausted = true; - return doc; - } - - // Try to align all other iterators - return doNext(doc); + return doNext(lead1.nextDoc()); } @Override public int advance(int target) throws IOException { - if (exhausted) { - return doc = NO_MORE_DOCS; - } - - // Advance the lead iterator - doc = lead.advance(target); - - if (doc == NO_MORE_DOCS) { - // Before marking as exhausted, check if any ResumableDISI can be expanded - if (canExpandAnyResumableDISI()) { - // Don't mark as exhausted yet - caller can expand and try again - return doc; - } - exhausted = true; - return doc; - } - - // Try to align all other iterators - return doNext(doc); + return doNext(lead1.advance(target)); } /** - * Check if any ResumableDISI in the iterators can be expanded (not exhausted) + * Core conjunction logic adapted from Lucene's ConjunctionDISI.doNext() + * with resumable expansion support. 
*/ - private boolean canExpandAnyResumableDISI() { - for (DocIdSetIterator iterator : iterators) { - if (iterator instanceof ResumableDISI) { - ResumableDISI resumableDISI = (ResumableDISI) iterator; - if (!resumableDISI.isExhausted()) { - return true; + private int doNext(int doc) throws IOException { + advanceHead: + for (; ; ) { + // Handle NO_MORE_DOCS with resumable expansion + if (doc == NO_MORE_DOCS) { + if (tryExpandResumableDISIs()) { + // After expansion, get the next document from lead1 + doc = lead1.nextDoc(); + if (doc == NO_MORE_DOCS) { + return NO_MORE_DOCS; // Truly exhausted + } + // Continue with the new document + } else { + return NO_MORE_DOCS; // No expansion possible } } - } - return false; - } - /** - * Reset the exhausted state so the conjunction can continue after ResumableDISIs are expanded. - * Ensures all iterators are positioned correctly to continue from where we left off. - */ - public void resetAfterExpansion() throws IOException { - // Only reset if we're not truly exhausted (i.e., some ResumableDISI was expanded) - if (canExpandAnyResumableDISI()) { - exhausted = false; - - // Position all iterators to continue from after the last valid document - // This ensures we don't reprocess documents we've already seen - int targetDoc = lastValidDoc + 1; - - // Advance the lead iterator to the target position - if (lead.docID() < targetDoc) { - lead.advance(targetDoc); - } - - // Advance all other iterators to the target position - for (DocIdSetIterator other : others) { - if (other.docID() < targetDoc) { - other.advance(targetDoc); + // Find agreement between the two iterators with the lower costs + // We special case them because they do not need the + // 'other.docID() < doc' check that the 'others' iterators need + final int next2 = lead2.advance(doc); + if (next2 != doc) { + doc = lead1.advance(next2); + if (next2 != doc) { + continue; } } - - // Set doc to lastValidDoc so the next nextDoc() call will advance from there - // This 
prevents reprocessing documents we've already seen - doc = lastValidDoc; - } - } - /** - * Coordinates multiple iterators to find documents that match all clauses. - * This is similar to ConjunctionDISI.doNext() but adapted for mixed iterator types. - */ - private int doNext(int doc) throws IOException { - advanceHead: for (;;) { - // Try to align all other iterators with the lead + // Then find agreement with other iterators for (DocIdSetIterator other : others) { + // other.docID() may already be equal to doc if we "continued advanceHead" + // on the previous iteration and the advance on the lead exactly matched. if (other.docID() < doc) { final int next = other.advance(doc); + if (next > doc) { - // This iterator is ahead, advance the lead to catch up - doc = lead.advance(next); - if (doc == NO_MORE_DOCS) { - // Before marking as exhausted, check if any ResumableDISI can be expanded - if (canExpandAnyResumableDISI()) { - // Don't mark as exhausted yet - caller can expand and try again - return this.doc = NO_MORE_DOCS; - } - exhausted = true; - return this.doc = NO_MORE_DOCS; - } + // iterator beyond the current doc - advance lead and continue to the new highest doc. 
+ doc = lead1.advance(next); continue advanceHead; } } } - // All iterators are aligned at the current doc - lastValidDoc = doc; // Remember this valid document - return this.doc = doc; + // Success - all iterators are on the same doc + return doc; } } + /** + * Try to expand ResumableDISIs when we hit NO_MORE_DOCS + * @return true if any ResumableDISI was expanded + */ + private boolean tryExpandResumableDISIs() throws IOException { + boolean anyExpanded = false; + + // Check all iterators for expansion + for (DocIdSetIterator iterator : allIterators) { + if (iterator instanceof ResumableDISI) { + ResumableDISI resumable = (ResumableDISI) iterator; + if (!resumable.isExhausted()) { + resumable.resetForNextBatch(); + anyExpanded = true; + } + } + } + + return anyExpanded; + } + @Override public long cost() { - // Return the cost of the cheapest iterator - return lead.cost(); + long minCost = Long.MAX_VALUE; + for (DocIdSetIterator iterator : allIterators) { + minCost = Math.min(minCost, iterator.cost()); + } + return minCost; } /** - * Returns whether this iterator has been exhausted. 
- * - * @return true if there are no more documents to score + * Reset method for compatibility (no longer needed with new architecture) */ - public boolean isExhausted() { - return exhausted; + public void resetAfterExpansion() throws IOException { + // No-op - expansion is now handled directly in doNext() } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java index 55141c4648e5c..5e27552f07bec 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java @@ -38,6 +38,19 @@ public ApproximateConjunctionScorer(float boost, ScoreMode scoreMode, List Date: Mon, 28 Jul 2025 22:18:07 +0000 Subject: [PATCH 16/38] Make BKDState a ScorerSupplier instance variable Signed-off-by: Sawan Srivastava --- .../ApproximatePointRangeQuery.java | 263 +++++++++--------- 1 file changed, 134 insertions(+), 129 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 7ff2777070818..65b416315a051 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -58,7 +58,9 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { public PointRangeQuery pointRangeQuery; private final Function valueToString; - private ResumableDISI.BKDState state; + // Remove shared state field to avoid concurrency issues + // BKD state is now passed as parameter to intersect methods + // private ResumableDISI.BKDState state; public ApproximatePointRangeQuery( String field, @@ -107,7 +109,8 @@ public void setSortOrder(SortOrder sortOrder) { } 
public void setBKDState(ResumableDISI.BKDState state) { - this.state = state; + // This method is no longer used since state is now per-shard + // Keeping for compatibility but it's a no-op } @Override @@ -246,27 +249,27 @@ private boolean checkValidPointValues(PointValues values) throws IOException { return true; } - private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount) + private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount, ResumableDISI.BKDState bkdState) throws IOException { - intersectLeft(visitor, pointTree, docCount); + intersectLeft(visitor, pointTree, docCount, bkdState); assert pointTree.moveToParent() == false; } - private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount) + private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount, ResumableDISI.BKDState bkdState) throws IOException { - intersectRight(visitor, pointTree, docCount); + intersectRight(visitor, pointTree, docCount, bkdState); assert pointTree.moveToParent() == false; } // custom intersect visitor to walk the left of the tree - public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) + public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState bkdState) throws IOException { // Check if we've already collected enough documents if (docCount[0] >= size) { // If we have state, save the current tree as the next node to visit - if (state != null) { - state.setCurrentTree(pointTree); - state.setInProgress(true); + if (bkdState != null) { + bkdState.setCurrentTree(pointTree); + bkdState.setInProgress(true); } return; } @@ -280,15 +283,15 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin if 
(pointTree.moveToChild() == false) { if (r == PointValues.Relation.CELL_INSIDE_QUERY) { // Save state before visiting docs if we're close to the limit - if (state != null && pointTree.size() + docCount[0] >= size) { + if (bkdState != null && pointTree.size() + docCount[0] >= size) { // Clone the tree before visiting docs PointValues.PointTree nextNode = pointTree.clone(); pointTree.visitDocIDs(visitor); // If we've hit the limit, save the next node if (docCount[0] >= size) { - state.setCurrentTree(nextNode); - state.setInProgress(true); + bkdState.setCurrentTree(nextNode); + bkdState.setInProgress(true); return; } } else { @@ -297,15 +300,15 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin } else { // CELL_CROSSES_QUERY // Save state before visiting docs if we're close to the limit - if (state != null && pointTree.size() + docCount[0] >= size) { + if (bkdState != null && pointTree.size() + docCount[0] >= size) { // Clone the tree before visiting docs PointValues.PointTree nextNode = pointTree.clone(); pointTree.visitDocValues(visitor); // If we've hit the limit, save the next node if (docCount[0] >= size) { - state.setCurrentTree(nextNode); - state.setInProgress(true); + bkdState.setCurrentTree(nextNode); + bkdState.setInProgress(true); return; } } else { @@ -322,12 +325,12 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin if (leftSize >= needed) { // Save state before processing left child if we're going to hit the limit - if (state != null && leftSize >= needed) { + if (bkdState != null && leftSize >= needed) { // Clone the current position PointValues.PointTree currentPos = pointTree.clone(); // Process left child - intersectLeft(visitor, pointTree, docCount); + intersectLeft(visitor, pointTree, docCount, bkdState); // If we've hit the limit, the state is already saved in the recursive call if (docCount[0] >= size) { @@ -335,7 +338,7 @@ public void intersectLeft(PointValues.IntersectVisitor 
visitor, PointValues.Poin } } else { // Process only left child - intersectLeft(visitor, pointTree, docCount); + intersectLeft(visitor, pointTree, docCount, bkdState); } pointTree.moveToParent(); return; @@ -351,7 +354,7 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin } // Process left child first - intersectLeft(visitor, pointTree, docCount); + intersectLeft(visitor, pointTree, docCount, bkdState); // If we've hit the limit, return (state is already saved in the recursive call) if (docCount[0] >= size) { @@ -360,21 +363,21 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin // Process right child if needed if (rightChild != null && docCount[0] < size) { - intersectLeft(visitor, rightChild, docCount); + intersectLeft(visitor, rightChild, docCount, bkdState); } pointTree.moveToParent(); } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) - public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) + public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState bkdState) throws IOException { // Check if we've already collected enough documents if (docCount[0] >= size) { // If we have state, save the current tree as the next node to visit - if (state != null) { - state.setCurrentTree(pointTree); - state.setInProgress(true); + if (bkdState != null) { + bkdState.setCurrentTree(pointTree); + bkdState.setInProgress(true); } return; } @@ -388,15 +391,15 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi if (pointTree.moveToChild() == false) { if (r == PointValues.Relation.CELL_INSIDE_QUERY) { // Save state before visiting docs if we're close to the limit - if (state != null && pointTree.size() + docCount[0] >= size) { + if (bkdState != null && pointTree.size() + docCount[0] >= size) { // Clone the tree 
before visiting docs PointValues.PointTree nextNode = pointTree.clone(); pointTree.visitDocIDs(visitor); // If we've hit the limit, save the next node if (docCount[0] >= size) { - state.setCurrentTree(nextNode); - state.setInProgress(true); + bkdState.setCurrentTree(nextNode); + bkdState.setInProgress(true); return; } } else { @@ -405,15 +408,15 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi } else { // CELL_CROSSES_QUERY // Save state before visiting docs if we're close to the limit - if (state != null && pointTree.size() + docCount[0] >= size) { + if (bkdState != null && pointTree.size() + docCount[0] >= size) { // Clone the tree before visiting docs PointValues.PointTree nextNode = pointTree.clone(); pointTree.visitDocValues(visitor); // If we've hit the limit, save the next node if (docCount[0] >= size) { - state.setCurrentTree(nextNode); - state.setInProgress(true); + bkdState.setCurrentTree(nextNode); + bkdState.setInProgress(true); return; } } else { @@ -435,12 +438,12 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi long needed = size - docCount[0]; if (rightSize >= needed) { // Save state before processing right child if we're going to hit the limit - if (state != null && rightSize >= needed) { + if (bkdState != null && rightSize >= needed) { // Clone the current position PointValues.PointTree currentPos = pointTree.clone(); // Process right child - intersectRight(visitor, pointTree, docCount); + intersectRight(visitor, pointTree, docCount, bkdState); // If we've hit the limit, the state is already saved in the recursive call if (docCount[0] >= size) { @@ -448,7 +451,7 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi } } else { // Right child has all we need - only process right - intersectRight(visitor, pointTree, docCount); + intersectRight(visitor, pointTree, docCount, bkdState); } pointTree.moveToParent(); return; @@ -457,7 +460,7 @@ public void 
intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi // Process right child first (for DESC) if (hasRightChild) { - intersectRight(visitor, pointTree, docCount); + intersectRight(visitor, pointTree, docCount, bkdState); // If we've hit the limit, return (state is already saved in the recursive call) if (docCount[0] >= size) { @@ -467,23 +470,23 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi // Process left child if needed if (docCount[0] < size) { - intersectRight(visitor, leftChild, docCount); + intersectRight(visitor, leftChild, docCount, bkdState); } pointTree.moveToParent(); } - private void captureStateBeforeIntersect() { + private void captureStateBeforeIntersect(ResumableDISI.BKDState state) { // Save the current state before intersect - stateBeforeIntersect = state.getCurrentTree() != null ? state.getCurrentTree().clone() : null; + tempBkdState = state.getCurrentTree() != null ? state.getCurrentTree().clone() : null; } - private void updateStateAfterIntersect(long[] docCount) { + private void updateStateAfterIntersect(long[] docCount, ResumableDISI.BKDState state) { // If we've collected enough documents, we need to save state - if (docCount[0] >= size && stateBeforeIntersect != null) { + if (docCount[0] >= size && tempBkdState != null) { // We've collected enough documents, save the state from before the intersect // This is a simplification - ideally we'd save the exact point where we stopped - state.setCurrentTree(stateBeforeIntersect); + state.setCurrentTree(tempBkdState); } else { // We've exhausted the tree state.setExhausted(true); @@ -491,12 +494,12 @@ private void updateStateAfterIntersect(long[] docCount) { } // Temporary variable to hold state during intersect - private PointValues.PointTree stateBeforeIntersect; + private PointValues.PointTree tempBkdState; @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { LeafReader reader = context.reader(); - 
long[] docCount = { 0 }; + long[] docCount = {0}; PointValues values = reader.getPointValues(pointRangeQuery.getField()); if (checkValidPointValues(values) == false) { @@ -507,114 +510,116 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return pointRangeQueryWeight.scorerSupplier(context); } else { - if (state == null) { - state = new ResumableDISI.BKDState(); - } - // Reset docCount if we're starting fresh - if (state.getCurrentTree() == null && !state.isExhausted()) { - docCount[0] = 0; - } else { - // Resume from where we left off - docCount[0] = state.getDocCount(); - } + // Reset docCount since we're starting fresh for this shard + docCount[0] = 0; + } - if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { - return new ScorerSupplier() { + if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { + return new ScorerSupplier() { - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); - long cost = -1; + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + long cost = -1; - @Override - public Scorer get(long leadCost) throws IOException { - // Check if we have a saved tree and we're not exhausted - if (state.getCurrentTree() == null && !state.isExhausted()) { - // First call - start from the root - state.setCurrentTree(values.getPointTree()); - } + // Create per-shard BKD state to avoid concurrency issues + ResumableDISI.BKDState shardState = new ResumableDISI.BKDState(); + + @Override + public Scorer get(long leadCost) throws IOException { + // Check if we have a saved tree and we're not exhausted + if (shardState.getCurrentTree() == null && !shardState.isExhausted()) { + // First call - start from the root + shardState.setCurrentTree(values.getPointTree()); + } else if (shardState.getCurrentTree() 
!= null) { + // Resume from where we left off + docCount[0] = shardState.getDocCount(); + } - // Only process if we haven't collected enough documents and we're not exhausted - if (!state.isExhausted() && docCount[0] < size) { - // Call intersect with the current tree - // The state will be updated inside intersectLeft - intersectLeft(state.getCurrentTree(), visitor, docCount); + // Only process if we haven't collected enough documents and we're not exhausted + if (!shardState.isExhausted() && docCount[0] < size) { + // Call intersect with the current tree, passing the shard state + // The state will be updated inside intersectLeft + intersectLeft(shardState.getCurrentTree(), visitor, docCount, shardState); - // Update the state's docCount - state.setDocCount(docCount[0]); + // Update the state's docCount + shardState.setDocCount(docCount[0]); - // If we didn't collect enough documents and we're not in progress, we've exhausted the tree - if (docCount[0] < size && !state.isInProgress()) { - state.setExhausted(true); - } + // If we didn't collect enough documents and we're not in progress, we've exhausted the tree + if (docCount[0] < size && !shardState.isInProgress()) { + shardState.setExhausted(true); } + } + + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(score(), scoreMode, iterator); + } - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(score(), scoreMode, iterator); + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; } + return cost; + } + }; + } else { + // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results + // than expected + final int deletedDocs = reader.numDeletedDocs(); + size += deletedDocs; + return new ScorerSupplier() { - @Override - public long cost() { - if (cost == -1) { - 
// Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; - } - return cost; + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + long cost = -1; + + // Create per-shard BKD state to avoid concurrency issues + ResumableDISI.BKDState state = new ResumableDISI.BKDState(); + + @Override + public Scorer get(long leadCost) throws IOException { + // Check if we have a saved tree and we're not exhausted + if (state.getCurrentTree() == null && !state.isExhausted()) { + // First call - start from the root + state.setCurrentTree(values.getPointTree()); } - }; - } else { - // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results - // than expected - final int deletedDocs = reader.numDeletedDocs(); - size += deletedDocs; - return new ScorerSupplier() { - - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); - long cost = -1; - - @Override - public Scorer get(long leadCost) throws IOException { - // Check if we have a saved tree and we're not exhausted - if (state.getCurrentTree() == null && !state.isExhausted()) { - // First call - start from the root - state.setCurrentTree(values.getPointTree()); - } - // Only process if we haven't collected enough documents and we're not exhausted - if (!state.isExhausted() && docCount[0] < size) { - // Call intersect with the current tree - // The state will be updated inside intersectRight - intersectRight(state.getCurrentTree(), visitor, docCount); + // Only process if we haven't collected enough documents and we're not exhausted + if (!state.isExhausted() && docCount[0] < size) { + // Call intersect with the current tree + // The state will be updated inside intersectRight + 
intersectRight(state.getCurrentTree(), visitor, docCount, state); - // Update the state's docCount - state.setDocCount(docCount[0]); + // Update the state's docCount + state.setDocCount(docCount[0]); - // If we didn't collect enough documents and we're not in progress, we've exhausted the tree - if (docCount[0] < size && !state.isInProgress()) { - state.setExhausted(true); - } + // If we didn't collect enough documents and we're not in progress, we've exhausted the tree + if (docCount[0] < size && !state.isInProgress()) { + state.setExhausted(true); } - - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(score(), scoreMode, iterator); } - @Override - public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; - } - return cost; + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(score(), scoreMode, iterator); + } + + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; } - }; - } + return cost; + } + }; } } + @Override public int count(LeafReaderContext context) throws IOException { return pointRangeQueryWeight.count(context); From 2f44e3a7dec1e118d0753bef548d72e9555ec5a5 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Tue, 29 Jul 2025 17:45:47 +0000 Subject: [PATCH 17/38] use Lucene's ConjunctionDISI Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 196 ++++++++-------- .../ApproximateConjunctionDISI.java | 5 +- .../ApproximatePointRangeQuery.java | 36 ++- .../approximate/ApproximateScoreQuery.java | 6 + .../search/approximate/ResumableDISI.java | 221 ++++++++++-------- 5 files changed, 256 insertions(+), 208 deletions(-) diff --git 
a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 119e86f136cae..3817d8bb4d0d9 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -1,15 +1,8 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - package org.opensearch.search.approximate; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.ConjunctionUtils; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; @@ -25,7 +18,7 @@ /** * A ScorerSupplier implementation for ApproximateBooleanQuery that creates resumable DocIdSetIterators - * for each clause and coordinates their usage in the boolean query context. + * for each clause and uses Lucene's ConjunctionUtils to coordinate them. 
*/ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { private final List clauseWeights; @@ -95,8 +88,31 @@ public Scorer get(long leadCost) throws IOException { } } - // Create an ApproximateConjunctionScorer with the clause iterators - return new ApproximateConjunctionScorer(boost, scoreMode, clauseIterators); + // Use Lucene's ConjunctionUtils to create the conjunction + DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); + + // Create a simple scorer that wraps the conjunction + return new Scorer() { + @Override + public DocIdSetIterator iterator() { + return conjunctionDISI; + } + + @Override + public float score() throws IOException { + return boost; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return boost; + } + + @Override + public int docID() { + return conjunctionDISI.docID(); + } + }; } /** @@ -125,8 +141,73 @@ public BulkScorer bulkScorer() throws IOException { } } - // Create an ApproximateBooleanBulkScorer with the clause iterators - return new ApproximateBooleanBulkScorer(clauseIterators, threshold); + // Use Lucene's ConjunctionUtils to create the conjunction + DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); + + // Create a simple bulk scorer that wraps the conjunction + return new BulkScorer() { + @Override + public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { + // Create a simple scorer for the collector + Scorer scorer = new Scorer() { + @Override + public DocIdSetIterator iterator() { + return conjunctionDISI; + } + + @Override + public float score() throws IOException { + return boost; + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return boost; + } + + @Override + public int docID() { + return conjunctionDISI.docID(); + } + }; + + collector.setScorer(scorer); + + // Track how many documents we've collected + int collected = 0; + int docID; 
+ + // Continue collecting until we reach the threshold + while (collected < threshold) { + // Get the next document from the conjunction + docID = conjunctionDISI.nextDoc(); + + if (docID == DocIdSetIterator.NO_MORE_DOCS) { + // No more documents - ResumableDISIs will expand internally if possible + break; + } + + if (docID >= max) { + // We've reached the end of the range + return docID; + } + + if (docID >= min && (acceptDocs == null || acceptDocs.get(docID))) { + // Collect the document + collector.collect(docID); + collected++; + } + } + + // We've either collected enough documents or exhausted all possibilities + return DocIdSetIterator.NO_MORE_DOCS; + } + + @Override + public long cost() { + return ApproximateBooleanScorerSupplier.this.cost(); + } + }; } /** @@ -155,93 +236,4 @@ public long cost() { } return cost; } - - /** - * A BulkScorer implementation that coordinates multiple DocIdSetIterators (including ResumableDISIs) to implement - * the circular scoring process described in the blog. 
- */ - private static class ApproximateBooleanBulkScorer extends BulkScorer { - private final List clauseIterators; - private final int threshold; - - public ApproximateBooleanBulkScorer(List clauseIterators, int threshold) { - this.clauseIterators = clauseIterators; - this.threshold = threshold; - } - - @Override - public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { - // Create an ApproximateConjunctionDISI once and reuse it (preserve conjunction state) - ApproximateConjunctionDISI conjunctionDISI = new ApproximateConjunctionDISI(clauseIterators); - - // Create a scorer for the collector that reuses the same conjunction - ApproximateConjunctionScorer scorer = new ApproximateConjunctionScorer(0.0f, ScoreMode.COMPLETE, conjunctionDISI); - - // Set the scorer on the collector - collector.setScorer(scorer); - - // Track how many documents we've collected - int collected = 0; - int docID; - - // Continue collecting until we reach the threshold - while (collected < threshold) { - // Get the next document from the conjunction - docID = conjunctionDISI.nextDoc(); - - if (docID == DocIdSetIterator.NO_MORE_DOCS) { - // No more documents in current state - try to expand ResumableDISIs - boolean anyExpanded = false; - for (DocIdSetIterator disi : clauseIterators) { - if (disi instanceof ResumableDISI) { - ResumableDISI resumableDISI = (ResumableDISI) disi; - if (!resumableDISI.isExhausted()) { - resumableDISI.resetForNextBatch(); // This expands the document set - anyExpanded = true; - } - } - } - - // If no ResumableDISIs were expanded, we're truly done - if (!anyExpanded) { - break; - } - - // Reset the conjunction so it can continue with expanded iterators -// conjunctionDISI.resetAfterExpansion(); - - // After expansion, check if we're already positioned on a valid document - docID = conjunctionDISI.docID(); - if (docID == DocIdSetIterator.NO_MORE_DOCS) { - // Still no document, try nextDoc() in the next iteration - continue; 
- } - // Fall through to process the current document - } - - if (docID >= max) { - // We've reached the end of the range - return docID; - } - - if (docID >= min && (acceptDocs == null || acceptDocs.get(docID))) { - // Collect the document - collector.collect(docID); - collected++; - } - } - - System.out.println(collected); - // We've either collected enough documents or exhausted all possibilities - return DocIdSetIterator.NO_MORE_DOCS; - } - - /** - * Same as {@link DocIdSetIterator#cost()} for bulk scorers. - */ - @Override - public long cost() { - return 0; - } - } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java index 8bc95f4d8ed83..a7712838e12ae 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java @@ -57,8 +57,7 @@ public int advance(int target) throws IOException { * with resumable expansion support. 
*/ private int doNext(int doc) throws IOException { - advanceHead: - for (; ; ) { + advanceHead: for (;;) { // Handle NO_MORE_DOCS with resumable expansion if (doc == NO_MORE_DOCS) { if (tryExpandResumableDISIs()) { @@ -116,7 +115,7 @@ private boolean tryExpandResumableDISIs() throws IOException { if (iterator instanceof ResumableDISI) { ResumableDISI resumable = (ResumableDISI) iterator; if (!resumable.isExhausted()) { - resumable.resetForNextBatch(); + // resumable.resetForNextBatch(); anyExpanded = true; } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 65b416315a051..4daa6c51b45f2 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -249,21 +249,33 @@ private boolean checkValidPointValues(PointValues values) throws IOException { return true; } - private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount, ResumableDISI.BKDState bkdState) - throws IOException { + private void intersectLeft( + PointValues.PointTree pointTree, + PointValues.IntersectVisitor visitor, + long[] docCount, + ResumableDISI.BKDState bkdState + ) throws IOException { intersectLeft(visitor, pointTree, docCount, bkdState); assert pointTree.moveToParent() == false; } - private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount, ResumableDISI.BKDState bkdState) - throws IOException { + private void intersectRight( + PointValues.PointTree pointTree, + PointValues.IntersectVisitor visitor, + long[] docCount, + ResumableDISI.BKDState bkdState + ) throws IOException { intersectRight(visitor, pointTree, docCount, bkdState); assert pointTree.moveToParent() == false; } // custom intersect visitor to walk 
the left of the tree - public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState bkdState) - throws IOException { + public void intersectLeft( + PointValues.IntersectVisitor visitor, + PointValues.PointTree pointTree, + long[] docCount, + ResumableDISI.BKDState bkdState + ) throws IOException { // Check if we've already collected enough documents if (docCount[0] >= size) { // If we have state, save the current tree as the next node to visit @@ -370,8 +382,12 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) - public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState bkdState) - throws IOException { + public void intersectRight( + PointValues.IntersectVisitor visitor, + PointValues.PointTree pointTree, + long[] docCount, + ResumableDISI.BKDState bkdState + ) throws IOException { // Check if we've already collected enough documents if (docCount[0] >= size) { // If we have state, save the current tree as the next node to visit @@ -499,7 +515,7 @@ private void updateStateAfterIntersect(long[] docCount, ResumableDISI.BKDState s @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { LeafReader reader = context.reader(); - long[] docCount = {0}; + long[] docCount = { 0 }; PointValues values = reader.getPointValues(pointRangeQuery.getField()); if (checkValidPointValues(values) == false) { @@ -510,7 +526,6 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti return pointRangeQueryWeight.scorerSupplier(context); } else { - // Reset docCount since we're starting fresh for this shard docCount[0] = 0; } @@ -619,7 +634,6 @@ public long cost() { } } - @Override public int count(LeafReaderContext context) throws IOException { return 
pointRangeQueryWeight.count(context); diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index 352cd3408aa67..bdb10e8967192 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -69,6 +69,12 @@ public void setContext(SearchContext context) { if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { appxResolved.setContext(context); } + } else { + for (BooleanClause boolClause : appxBool.boolQuery.clauses()) { + if (boolClause.query() instanceof ApproximateScoreQuery apprxQuery) { + apprxQuery.setContext(context); + } + } } needsRewrite = true; } else if (resolvedQuery instanceof BooleanQuery) { diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index 15c9c03a58aae..3fda22516271d 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -16,158 +16,195 @@ import java.io.IOException; /** - * A resumable DocIdSetIterator that can be used to score documents in batches. - * This class wraps a ScorerSupplier and creates a new Scorer/DocIdSetIterator only when needed. - * It maintains state between calls to enable resuming from where it left off. + * A resumable DocIdSetIterator that internally expands when it reaches NO_MORE_DOCS. + * On the surface, this behaves identically to a regular DISI, but internally it can + * expand by scoring additional documents when needed. * - * This implementation is specifically designed for the approximation framework to enable - * early termination while preserving state between scoring cycles. 
+ * The expansion is completely internal - external callers see a normal DISI interface + * that continues to return documents even after initially hitting NO_MORE_DOCS. */ public class ResumableDISI extends DocIdSetIterator { - private static final int DEFAULT_BATCH_SIZE = 10_000; + private static final int DEFAULT_EXPANSION_SIZE = 10_000; private final ScorerSupplier scorerSupplier; + private final int expansionSize; + + // Current state private DocIdSetIterator currentDisi; - private final int batchSize; - private boolean exhausted = false; + private int currentDocId = -1; + private boolean fullyExhausted = false; - // State tracking - track batches, not individual document movements - private int lastDocID = -1; - private int batchCount = 0; // How many batches we've created - private boolean needsNewBatch = true; // Whether we need to create a new batch + // BKD traversal state for approximatable queries + private BKDState bkdState; + private int documentsScored = 0; // Total documents scored across all expansions /** - * Creates a new ResumableDISI with the default batch size of 10,000 documents. + * Creates a new ResumableDISI with the default expansion size of 10,000 documents. * * @param scorerSupplier The scorer supplier to get scorers from */ public ResumableDISI(ScorerSupplier scorerSupplier) { - this(scorerSupplier, DEFAULT_BATCH_SIZE); + this(scorerSupplier, DEFAULT_EXPANSION_SIZE); } /** - * Creates a new ResumableDISI with the specified batch size. + * Creates a new ResumableDISI with the specified expansion size. 
* * @param scorerSupplier The scorer supplier to get scorers from - * @param batchSize The number of documents to score in each batch + * @param expansionSize The number of documents to score in each expansion */ - public ResumableDISI(ScorerSupplier scorerSupplier, int batchSize) { + public ResumableDISI(ScorerSupplier scorerSupplier, int expansionSize) { this.scorerSupplier = scorerSupplier; - this.batchSize = batchSize; - } - - /** - * Initializes or resets the internal DocIdSetIterator. - * For approximatable queries, this leverages their existing resumable mechanism. - * For non-approximatable queries, this creates new scorers as needed. - * - * @return The current DocIdSetIterator - * @throws IOException If there's an error getting the scorer - */ - private DocIdSetIterator getOrCreateDisi() throws IOException { - if (exhausted) { - return currentDisi; // Already exhausted, no need to create a new one - } - - if (currentDisi == null || needsNewBatch) { - // Get a new scorer and its iterator - // For approximatable queries, the scorer supplier will handle resumable state internally - Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); - currentDisi = scorer.iterator(); - - // For non-approximatable queries, we need to advance past the last document - // For approximatable queries, they handle this internally via their BKD state - if (lastDocID >= 0) { - // Check if we need to advance (for non-approximatable queries) - if (currentDisi.docID() <= lastDocID) { - currentDisi.advance(lastDocID + 1); - } - } - - // Mark that we've created a new batch - batchCount++; - needsNewBatch = false; - } - - return currentDisi; + this.expansionSize = expansionSize; + this.bkdState = new BKDState(); } @Override public int docID() { - if (currentDisi == null) { - return -1; - } - return currentDisi.docID(); + return currentDocId; } @Override public int nextDoc() throws IOException { - DocIdSetIterator disi = getOrCreateDisi(); - int doc = disi.nextDoc(); + if 
(fullyExhausted) { + return NO_MORE_DOCS; + } + + // If we don't have a current iterator, get one + if (currentDisi == null) { + if (!expandInternally()) { + return NO_MORE_DOCS; + } + // expandInternally() already positioned us on the first document + return currentDocId; + } + + // Try to get the next document from current iterator + int doc = currentDisi.nextDoc(); if (doc != NO_MORE_DOCS) { - lastDocID = doc; - } else { - exhausted = true; + currentDocId = doc; + return doc; + } + + // Current iterator exhausted, try to expand internally + if (expandInternally()) { + // expandInternally() already positioned us on the first document of the new batch + return currentDocId; } - return doc; + // No more expansion possible + currentDocId = NO_MORE_DOCS; + return NO_MORE_DOCS; } @Override public int advance(int target) throws IOException { - DocIdSetIterator disi = getOrCreateDisi(); - int doc = disi.advance(target); + if (fullyExhausted) { + return NO_MORE_DOCS; + } - if (doc != NO_MORE_DOCS) { - lastDocID = doc; + // If we don't have a current iterator, get one + if (currentDisi == null) { + if (!expandInternally()) { + return NO_MORE_DOCS; + } + // If the first document is >= target, we're good + if (currentDocId >= target) { + return currentDocId; + } + // Otherwise, advance to target + int doc = currentDisi.advance(target); + if (doc != NO_MORE_DOCS) { + currentDocId = doc; + return doc; + } + // Fall through to try expansion } else { - exhausted = true; + // Try to advance current iterator + int doc = currentDisi.advance(target); + if (doc != NO_MORE_DOCS) { + currentDocId = doc; + return doc; + } + // Current iterator exhausted, try to expand } - return doc; - } + // Current iterator exhausted, try to expand internally + if (expandInternally()) { + // If the first document of new batch is >= target, we're good + if (currentDocId >= target) { + return currentDocId; + } + // Otherwise, advance to target + int doc = currentDisi.advance(target); + if (doc != 
NO_MORE_DOCS) { + currentDocId = doc; + return doc; + } + } - @Override - public long cost() { - return scorerSupplier.cost(); + // No more expansion possible + currentDocId = NO_MORE_DOCS; + fullyExhausted = true; + return NO_MORE_DOCS; } /** - * Resets the iterator to start a new batch from the last document ID. - * This allows the caller to continue scoring from where it left off. + * Expands the iterator internally by getting a new scorer from the supplier. + * This is called when we hit NO_MORE_DOCS but more documents might be available. + * + * @return true if expansion was successful, false if fully exhausted + * @throws IOException If there's an error getting the scorer */ - public void resetForNextBatch() { - if (!exhausted) { - needsNewBatch = true; // Mark that we need a new batch on next access + private boolean expandInternally() throws IOException { + if (fullyExhausted) { + return false; } + + // Get a new scorer from the supplier - this will resume from saved BKD state + Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); + if (scorer == null) { + fullyExhausted = true; + return false; + } + + currentDisi = scorer.iterator(); + documentsScored += expansionSize; // Track total documents scored + + // Check if the new iterator has any documents + int firstDoc = currentDisi.nextDoc(); + if (firstDoc == NO_MORE_DOCS) { + fullyExhausted = true; + return false; + } + + // Position the iterator on the first document + currentDocId = firstDoc; + return true; } - /** - * Returns the number of batches created so far. - * - * @return The number of batches created - */ - public int getBatchCount() { - return batchCount; + @Override + public long cost() { + return scorerSupplier.cost(); } /** - * Returns whether this iterator has been exhausted. + * Returns whether this iterator has been fully exhausted. 
* * @return true if there are no more documents to score */ public boolean isExhausted() { - return exhausted; + return fullyExhausted; } /** - * Returns the last document ID that was scored. + * Returns the total number of documents scored across all expansions. * - * @return The last document ID, or -1 if no documents have been scored + * @return The total number of documents scored */ - public int getLastDocID() { - return lastDocID; + public int getDocumentsScored() { + return documentsScored; } /** From 349a9f254877ee98ed9673939c068d1e539290b6 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Tue, 29 Jul 2025 11:51:55 -0700 Subject: [PATCH 18/38] update state management in ApproximatePointRangeQuery Signed-off-by: Sawan Srivastava --- .../ApproximatePointRangeQuery.java | 232 +++++++++--------- 1 file changed, 111 insertions(+), 121 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 4daa6c51b45f2..5adaac357fb28 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -278,7 +278,7 @@ public void intersectLeft( ) throws IOException { // Check if we've already collected enough documents if (docCount[0] >= size) { - // If we have state, save the current tree as the next node to visit + // Save current position for resumption if (bkdState != null) { bkdState.setCurrentTree(pointTree); bkdState.setInProgress(true); @@ -293,38 +293,25 @@ public void intersectLeft( // Handle leaf nodes if (pointTree.moveToChild() == false) { + long docsBefore = docCount[0]; + if (r == PointValues.Relation.CELL_INSIDE_QUERY) { - // Save state before visiting docs if we're close to the limit - if (bkdState != null && pointTree.size() + docCount[0] >= size) { - // Clone the tree before 
visiting docs - PointValues.PointTree nextNode = pointTree.clone(); - pointTree.visitDocIDs(visitor); - - // If we've hit the limit, save the next node - if (docCount[0] >= size) { - bkdState.setCurrentTree(nextNode); - bkdState.setInProgress(true); - return; - } - } else { - pointTree.visitDocIDs(visitor); - } + pointTree.visitDocIDs(visitor); } else { // CELL_CROSSES_QUERY - // Save state before visiting docs if we're close to the limit - if (bkdState != null && pointTree.size() + docCount[0] >= size) { - // Clone the tree before visiting docs - PointValues.PointTree nextNode = pointTree.clone(); - pointTree.visitDocValues(visitor); - - // If we've hit the limit, save the next node - if (docCount[0] >= size) { - bkdState.setCurrentTree(nextNode); - bkdState.setInProgress(true); - return; - } + pointTree.visitDocValues(visitor); + } + + // After visiting docs, check if we hit the limit + if (docCount[0] >= size && bkdState != null) { + // We've processed this leaf and hit the limit + // Find the next unvisited position by moving up the tree + PointValues.PointTree nextPosition = findNextUnvisitedPosition(pointTree); + if (nextPosition != null) { + bkdState.setCurrentTree(nextPosition); + bkdState.setInProgress(true); } else { - pointTree.visitDocValues(visitor); + bkdState.setExhausted(true); } } return; @@ -336,28 +323,14 @@ public void intersectLeft( long needed = size - docCount[0]; if (leftSize >= needed) { - // Save state before processing left child if we're going to hit the limit - if (bkdState != null && leftSize >= needed) { - // Clone the current position - PointValues.PointTree currentPos = pointTree.clone(); - - // Process left child - intersectLeft(visitor, pointTree, docCount, bkdState); - - // If we've hit the limit, the state is already saved in the recursive call - if (docCount[0] >= size) { - return; - } - } else { - // Process only left child - intersectLeft(visitor, pointTree, docCount, bkdState); - } + // Process only left child + 
intersectLeft(visitor, pointTree, docCount, bkdState); pointTree.moveToParent(); return; } } - // We need both children - now clone right + // We need both children - clone right child before processing left PointValues.PointTree rightChild = null; if (pointTree.moveToSibling()) { rightChild = pointTree.clone(); @@ -368,12 +341,18 @@ public void intersectLeft( // Process left child first intersectLeft(visitor, pointTree, docCount, bkdState); - // If we've hit the limit, return (state is already saved in the recursive call) + // If we've hit the limit during left processing, check if right child should be saved if (docCount[0] >= size) { + if (rightChild != null && bkdState != null && !bkdState.isInProgress()) { + // Left child processing completed but didn't set next position + // Right child is the next unvisited position + bkdState.setCurrentTree(rightChild); + bkdState.setInProgress(true); + } return; } - // Process right child if needed + // Process right child if needed and available if (rightChild != null && docCount[0] < size) { intersectLeft(visitor, rightChild, docCount, bkdState); } @@ -381,6 +360,22 @@ public void intersectLeft( pointTree.moveToParent(); } + // Helper method to find the next unvisited position after processing a leaf + private PointValues.PointTree findNextUnvisitedPosition(PointValues.PointTree currentLeaf) throws IOException { + PointValues.PointTree tree = currentLeaf.clone(); + + // Move up the tree to find the next unvisited sibling or ancestor's sibling + while (tree.moveToParent()) { + // Try to move to sibling (next unvisited subtree) + if (tree.moveToSibling()) { + return tree.clone(); + } + } + + // No more unvisited positions found + return null; + } + // custom intersect visitor to walk the right of tree (from rightmost leaf going left) public void intersectRight( PointValues.IntersectVisitor visitor, @@ -390,7 +385,7 @@ public void intersectRight( ) throws IOException { // Check if we've already collected enough documents if 
(docCount[0] >= size) { - // If we have state, save the current tree as the next node to visit + // Save current position for resumption if (bkdState != null) { bkdState.setCurrentTree(pointTree); bkdState.setInProgress(true); @@ -405,38 +400,25 @@ public void intersectRight( // Handle leaf nodes if (pointTree.moveToChild() == false) { + long docsBefore = docCount[0]; + if (r == PointValues.Relation.CELL_INSIDE_QUERY) { - // Save state before visiting docs if we're close to the limit - if (bkdState != null && pointTree.size() + docCount[0] >= size) { - // Clone the tree before visiting docs - PointValues.PointTree nextNode = pointTree.clone(); - pointTree.visitDocIDs(visitor); - - // If we've hit the limit, save the next node - if (docCount[0] >= size) { - bkdState.setCurrentTree(nextNode); - bkdState.setInProgress(true); - return; - } - } else { - pointTree.visitDocIDs(visitor); - } + pointTree.visitDocIDs(visitor); } else { // CELL_CROSSES_QUERY - // Save state before visiting docs if we're close to the limit - if (bkdState != null && pointTree.size() + docCount[0] >= size) { - // Clone the tree before visiting docs - PointValues.PointTree nextNode = pointTree.clone(); - pointTree.visitDocValues(visitor); - - // If we've hit the limit, save the next node - if (docCount[0] >= size) { - bkdState.setCurrentTree(nextNode); - bkdState.setInProgress(true); - return; - } + pointTree.visitDocValues(visitor); + } + + // After visiting docs, check if we hit the limit + if (docCount[0] >= size && bkdState != null) { + // We've processed this leaf and hit the limit + // Find the next unvisited position by moving up the tree (for right traversal) + PointValues.PointTree nextPosition = findNextUnvisitedPositionRight(pointTree); + if (nextPosition != null) { + bkdState.setCurrentTree(nextPosition); + bkdState.setInProgress(true); } else { - pointTree.visitDocValues(visitor); + bkdState.setExhausted(true); } } return; @@ -453,33 +435,25 @@ public void intersectRight( long 
rightSize = pointTree.size(); long needed = size - docCount[0]; if (rightSize >= needed) { - // Save state before processing right child if we're going to hit the limit - if (bkdState != null && rightSize >= needed) { - // Clone the current position - PointValues.PointTree currentPos = pointTree.clone(); - - // Process right child - intersectRight(visitor, pointTree, docCount, bkdState); - - // If we've hit the limit, the state is already saved in the recursive call - if (docCount[0] >= size) { - return; - } - } else { - // Right child has all we need - only process right - intersectRight(visitor, pointTree, docCount, bkdState); - } + // Right child has all we need - only process right + intersectRight(visitor, pointTree, docCount, bkdState); pointTree.moveToParent(); return; } } - // Process right child first (for DESC) + // Process right child first (for DESC order) if (hasRightChild) { intersectRight(visitor, pointTree, docCount, bkdState); - // If we've hit the limit, return (state is already saved in the recursive call) + // If we've hit the limit during right processing, check if left child should be saved if (docCount[0] >= size) { + if (bkdState != null && !bkdState.isInProgress()) { + // Right child processing completed but didn't set next position + // Left child is the next unvisited position + bkdState.setCurrentTree(leftChild); + bkdState.setInProgress(true); + } return; } } @@ -492,25 +466,36 @@ public void intersectRight( pointTree.moveToParent(); } - private void captureStateBeforeIntersect(ResumableDISI.BKDState state) { - // Save the current state before intersect - tempBkdState = state.getCurrentTree() != null ? 
state.getCurrentTree().clone() : null; - } - - private void updateStateAfterIntersect(long[] docCount, ResumableDISI.BKDState state) { - // If we've collected enough documents, we need to save state - if (docCount[0] >= size && tempBkdState != null) { - // We've collected enough documents, save the state from before the intersect - // This is a simplification - ideally we'd save the exact point where we stopped - state.setCurrentTree(tempBkdState); - } else { - // We've exhausted the tree - state.setExhausted(true); + // Helper method to find the next unvisited position for right traversal + private PointValues.PointTree findNextUnvisitedPositionRight(PointValues.PointTree currentLeaf) throws IOException { + PointValues.PointTree tree = currentLeaf.clone(); + + // For right traversal, we need to find the next position going from right to left + // Move up the tree to find the next unvisited left sibling or ancestor's left sibling + while (tree.moveToParent()) { + // Check if we came from the right child + PointValues.PointTree parent = tree.clone(); + if (parent.moveToChild()) { + // We're at left child, check if there's a left sibling to process + // For right traversal, after processing right subtree, we process left subtree + PointValues.PointTree leftSibling = parent.clone(); + if (!isCurrentPosition(leftSibling, tree)) { + // This left child hasn't been processed yet + return leftSibling; + } + } } + + // No more unvisited positions found + return null; } - // Temporary variable to hold state during intersect - private PointValues.PointTree tempBkdState; + // Helper to check if two tree positions are the same + private boolean isCurrentPosition(PointValues.PointTree tree1, PointValues.PointTree tree2) { + // Simple comparison - in a real implementation, you'd compare the actual tree positions + // For now, we'll use a conservative approach + return false; // Always assume different positions to be safe + } @Override public ScorerSupplier 
scorerSupplier(LeafReaderContext context) throws IOException { @@ -524,10 +509,6 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti // values.size(): total points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) if (size > values.size()) { return pointRangeQueryWeight.scorerSupplier(context); - } else { - - // Reset docCount since we're starting fresh for this shard - docCount[0] = 0; } if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { @@ -546,15 +527,18 @@ public Scorer get(long leadCost) throws IOException { if (shardState.getCurrentTree() == null && !shardState.isExhausted()) { // First call - start from the root shardState.setCurrentTree(values.getPointTree()); - } else if (shardState.getCurrentTree() != null) { + docCount[0] = 0; // Reset doc count for first call + } else if (shardState.getCurrentTree() != null && !shardState.isExhausted()) { // Resume from where we left off - docCount[0] = shardState.getDocCount(); + docCount[0] = (int) shardState.getDocCount(); } // Only process if we haven't collected enough documents and we're not exhausted - if (!shardState.isExhausted() && docCount[0] < size) { + if (!shardState.isExhausted() && docCount[0] < size && shardState.getCurrentTree() != null) { + // Reset the in-progress flag before processing + shardState.setInProgress(false); + // Call intersect with the current tree, passing the shard state - // The state will be updated inside intersectLeft intersectLeft(shardState.getCurrentTree(), visitor, docCount, shardState); // Update the state's docCount @@ -600,12 +584,18 @@ public Scorer get(long leadCost) throws IOException { if (state.getCurrentTree() == null && !state.isExhausted()) { // First call - start from the root state.setCurrentTree(values.getPointTree()); + docCount[0] = 0; // Reset doc count for first call + } else if (state.getCurrentTree() != null && !state.isExhausted()) { + // Resume from where we left off + 
docCount[0] = (int) state.getDocCount(); } // Only process if we haven't collected enough documents and we're not exhausted - if (!state.isExhausted() && docCount[0] < size) { + if (!state.isExhausted() && docCount[0] < size && state.getCurrentTree() != null) { + // Reset the in-progress flag before processing + state.setInProgress(false); + // Call intersect with the current tree - // The state will be updated inside intersectRight intersectRight(state.getCurrentTree(), visitor, docCount, state); // Update the state's docCount From 6d51f1eb2df4b99cd42d9bf72fc62da63e76ee94 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 31 Jul 2025 18:51:24 +0000 Subject: [PATCH 19/38] working ApproximatePointRangeQuery Signed-off-by: Sawan Srivastava --- .../ApproximatePointRangeQuery.java | 329 ++++++------------ 1 file changed, 105 insertions(+), 224 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 5adaac357fb28..7e1d3fb45e0ad 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -30,6 +30,7 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.DocIdSetBuilder; +import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IntsRef; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumericPointEncoder; @@ -58,10 +59,6 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { public PointRangeQuery pointRangeQuery; private final Function valueToString; - // Remove shared state field to avoid concurrency issues - // BKD state is now passed as parameter to intersect methods - // private ResumableDISI.BKDState state; - public ApproximatePointRangeQuery( String 
field, byte[] lowerPoint, @@ -108,11 +105,6 @@ public void setSortOrder(SortOrder sortOrder) { this.sortOrder = sortOrder; } - public void setBKDState(ResumableDISI.BKDState state) { - // This method is no longer used since state is now per-shard - // Keeping for compatibility but it's a no-op - } - @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); @@ -249,34 +241,21 @@ private boolean checkValidPointValues(PointValues values) throws IOException { return true; } - private void intersectLeft( - PointValues.PointTree pointTree, - PointValues.IntersectVisitor visitor, - long[] docCount, - ResumableDISI.BKDState bkdState - ) throws IOException { - intersectLeft(visitor, pointTree, docCount, bkdState); + private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount, ResumableDISI.BKDState state) + throws IOException { + intersectLeft(visitor, pointTree, docCount, state); assert pointTree.moveToParent() == false; } - private void intersectRight( - PointValues.PointTree pointTree, - PointValues.IntersectVisitor visitor, - long[] docCount, - ResumableDISI.BKDState bkdState - ) throws IOException { - intersectRight(visitor, pointTree, docCount, bkdState); + private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount) + throws IOException { + intersectRight(visitor, pointTree, docCount); assert pointTree.moveToParent() == false; } // custom intersect visitor to walk the left of the tree - public void intersectLeft( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - long[] docCount, - ResumableDISI.BKDState bkdState - ) throws IOException { - // Check if we've already collected enough documents + public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState bkdState) + throws IOException { if (docCount[0] 
>= size) { // Save current position for resumption if (bkdState != null) { @@ -285,16 +264,13 @@ public void intersectLeft( } return; } - - PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + PointValues.Relation r = visitor.compare(bkdState.getCurrentTree().getMinPackedValue(), bkdState.getCurrentTree().getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { return; } // Handle leaf nodes if (pointTree.moveToChild() == false) { - long docsBefore = docCount[0]; - if (r == PointValues.Relation.CELL_INSIDE_QUERY) { pointTree.visitDocIDs(visitor); } else { @@ -302,21 +278,20 @@ public void intersectLeft( pointTree.visitDocValues(visitor); } - // After visiting docs, check if we hit the limit - if (docCount[0] >= size && bkdState != null) { - // We've processed this leaf and hit the limit - // Find the next unvisited position by moving up the tree - PointValues.PointTree nextPosition = findNextUnvisitedPosition(pointTree); - if (nextPosition != null) { - bkdState.setCurrentTree(nextPosition); - bkdState.setInProgress(true); - } else { - bkdState.setExhausted(true); - } - } +// After visiting docs, check if we hit the limit +// if (docCount[0] >= size && bkdState != null) { +// // We've processed this leaf and hit the limit +// // Find the next unvisited position by moving up the tree +// PointValues.PointTree nextPosition = findNextUnvisitedPosition(pointTree); +// if (nextPosition != null) { +// bkdState.setCurrentTree(nextPosition); +// bkdState.setInProgress(true); +// } else { +// bkdState.setExhausted(true); +// } +// } return; } - // For CELL_INSIDE_QUERY, check if we can skip right child if (r == PointValues.Relation.CELL_INSIDE_QUERY) { long leftSize = pointTree.size(); @@ -329,34 +304,18 @@ public void intersectLeft( return; } } - - // We need both children - clone right child before processing left + // We need both children - now clone right PointValues.PointTree rightChild = null; if 
(pointTree.moveToSibling()) { rightChild = pointTree.clone(); pointTree.moveToParent(); pointTree.moveToChild(); } - - // Process left child first + // Process both children: left first, then right if needed intersectLeft(visitor, pointTree, docCount, bkdState); - - // If we've hit the limit during left processing, check if right child should be saved - if (docCount[0] >= size) { - if (rightChild != null && bkdState != null && !bkdState.isInProgress()) { - // Left child processing completed but didn't set next position - // Right child is the next unvisited position - bkdState.setCurrentTree(rightChild); - bkdState.setInProgress(true); - } - return; - } - - // Process right child if needed and available - if (rightChild != null && docCount[0] < size) { + if (docCount[0] < size && rightChild != null) { intersectLeft(visitor, rightChild, docCount, bkdState); } - pointTree.moveToParent(); } @@ -365,7 +324,7 @@ private PointValues.PointTree findNextUnvisitedPosition(PointValues.PointTree cu PointValues.PointTree tree = currentLeaf.clone(); // Move up the tree to find the next unvisited sibling or ancestor's sibling - while (tree.moveToParent()) { + if (tree.moveToParent()) { // Try to move to sibling (next unvisited subtree) if (tree.moveToSibling()) { return tree.clone(); @@ -377,126 +336,50 @@ private PointValues.PointTree findNextUnvisitedPosition(PointValues.PointTree cu } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) - public void intersectRight( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - long[] docCount, - ResumableDISI.BKDState bkdState - ) throws IOException { - // Check if we've already collected enough documents + public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) + throws IOException { if (docCount[0] >= size) { - // Save current position for resumption - if (bkdState != null) { - bkdState.setCurrentTree(pointTree); - 
bkdState.setInProgress(true); - } return; } - PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { return; } - // Handle leaf nodes if (pointTree.moveToChild() == false) { - long docsBefore = docCount[0]; - if (r == PointValues.Relation.CELL_INSIDE_QUERY) { pointTree.visitDocIDs(visitor); } else { // CELL_CROSSES_QUERY pointTree.visitDocValues(visitor); } - - // After visiting docs, check if we hit the limit - if (docCount[0] >= size && bkdState != null) { - // We've processed this leaf and hit the limit - // Find the next unvisited position by moving up the tree (for right traversal) - PointValues.PointTree nextPosition = findNextUnvisitedPositionRight(pointTree); - if (nextPosition != null) { - bkdState.setCurrentTree(nextPosition); - bkdState.setInProgress(true); - } else { - bkdState.setExhausted(true); - } - } return; } - // Internal node - get left child reference (we're at left child initially) PointValues.PointTree leftChild = pointTree.clone(); - // Move to right child if it exists boolean hasRightChild = pointTree.moveToSibling(); - // For CELL_INSIDE_QUERY, check if we can skip left child if (r == PointValues.Relation.CELL_INSIDE_QUERY && hasRightChild) { long rightSize = pointTree.size(); long needed = size - docCount[0]; if (rightSize >= needed) { // Right child has all we need - only process right - intersectRight(visitor, pointTree, docCount, bkdState); + intersectRight(visitor, pointTree, docCount); pointTree.moveToParent(); return; } } - - // Process right child first (for DESC order) + // Process both children: right first (for DESC), then left if needed if (hasRightChild) { - intersectRight(visitor, pointTree, docCount, bkdState); - - // If we've hit the limit during right processing, check if left child should be saved - if (docCount[0] >= size) { - if (bkdState != null && !bkdState.isInProgress()) { - // Right child processing completed but 
didn't set next position - // Left child is the next unvisited position - bkdState.setCurrentTree(leftChild); - bkdState.setInProgress(true); - } - return; - } + intersectRight(visitor, pointTree, docCount); } - - // Process left child if needed if (docCount[0] < size) { - intersectRight(visitor, leftChild, docCount, bkdState); + intersectRight(visitor, leftChild, docCount); } - pointTree.moveToParent(); } - // Helper method to find the next unvisited position for right traversal - private PointValues.PointTree findNextUnvisitedPositionRight(PointValues.PointTree currentLeaf) throws IOException { - PointValues.PointTree tree = currentLeaf.clone(); - - // For right traversal, we need to find the next position going from right to left - // Move up the tree to find the next unvisited left sibling or ancestor's left sibling - while (tree.moveToParent()) { - // Check if we came from the right child - PointValues.PointTree parent = tree.clone(); - if (parent.moveToChild()) { - // We're at left child, check if there's a left sibling to process - // For right traversal, after processing right subtree, we process left subtree - PointValues.PointTree leftSibling = parent.clone(); - if (!isCurrentPosition(leftSibling, tree)) { - // This left child hasn't been processed yet - return leftSibling; - } - } - } - - // No more unvisited positions found - return null; - } - - // Helper to check if two tree positions are the same - private boolean isCurrentPosition(PointValues.PointTree tree1, PointValues.PointTree tree2) { - // Simple comparison - in a real implementation, you'd compare the actual tree positions - // For now, we'll use a conservative approach - return false; // Always assume different positions to be safe - } - @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { LeafReader reader = context.reader(); @@ -509,69 +392,11 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti // values.size(): total 
points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) if (size > values.size()) { return pointRangeQueryWeight.scorerSupplier(context); - } - - if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { - return new ScorerSupplier() { - - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); - long cost = -1; - - // Create per-shard BKD state to avoid concurrency issues - ResumableDISI.BKDState shardState = new ResumableDISI.BKDState(); - - @Override - public Scorer get(long leadCost) throws IOException { - // Check if we have a saved tree and we're not exhausted - if (shardState.getCurrentTree() == null && !shardState.isExhausted()) { - // First call - start from the root - shardState.setCurrentTree(values.getPointTree()); - docCount[0] = 0; // Reset doc count for first call - } else if (shardState.getCurrentTree() != null && !shardState.isExhausted()) { - // Resume from where we left off - docCount[0] = (int) shardState.getDocCount(); - } - - // Only process if we haven't collected enough documents and we're not exhausted - if (!shardState.isExhausted() && docCount[0] < size && shardState.getCurrentTree() != null) { - // Reset the in-progress flag before processing - shardState.setInProgress(false); - - // Call intersect with the current tree, passing the shard state - intersectLeft(shardState.getCurrentTree(), visitor, docCount, shardState); - - // Update the state's docCount - shardState.setDocCount(docCount[0]); - - // If we didn't collect enough documents and we're not in progress, we've exhausted the tree - if (docCount[0] < size && !shardState.isInProgress()) { - shardState.setExhausted(true); - } - } - - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(score(), scoreMode, iterator); - } - - @Override - public long cost() { - if (cost == -1) { - // Computing the cost 
may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; - } - return cost; - } - }; } else { - // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results - // than expected - final int deletedDocs = reader.numDeletedDocs(); - size += deletedDocs; - return new ScorerSupplier() { + if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { + return new ScorerSupplier() { - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); long cost = -1; @@ -586,8 +411,11 @@ public Scorer get(long leadCost) throws IOException { state.setCurrentTree(values.getPointTree()); docCount[0] = 0; // Reset doc count for first call } else if (state.getCurrentTree() != null && !state.isExhausted()) { + result = new DocIdSetBuilder(reader.maxDoc(), values); + // Resume from where we left off docCount[0] = (int) state.getDocCount(); + } // Only process if we haven't collected enough documents and we're not exhausted @@ -595,32 +423,85 @@ public Scorer get(long leadCost) throws IOException { // Reset the in-progress flag before processing state.setInProgress(false); - // Call intersect with the current tree - intersectRight(state.getCurrentTree(), visitor, docCount, state); + // Call intersect with the current tree, passing the shard state + intersectLeft(state.getCurrentTree(), visitor, docCount, state); // Update the state's docCount state.setDocCount(docCount[0]); // If we didn't collect enough documents and we're not in progress, we've exhausted the tree - if (docCount[0] < size && !state.isInProgress()) { - state.setExhausted(true); - } +// if (docCount[0] < size && !state.isInProgress()) { +// state.setExhausted(true); +// } } +// if (state.isExhausted()){ +// return new Scorer() { +// @Override +// 
public int docID() { +// return 0; +// } +// +// @Override +// public DocIdSetIterator iterator() { +// return null; +// } +// +// @Override +// public float getMaxScore(int upTo) throws IOException { +// return 0; +// } +// +// @Override +// public float score() throws IOException { +// return 0; +// } +// }; +// } DocIdSetIterator iterator = result.build().iterator(); return new ConstantScoreScorer(score(), scoreMode, iterator); } - @Override - public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; } - return cost; - } - }; + }; + } else { + // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results + // than expected + final int deletedDocs = reader.numDeletedDocs(); + size += deletedDocs; + return new ScorerSupplier() { + + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + long cost = -1; + + @Override + public Scorer get(long leadCost) throws IOException { + intersectRight(values.getPointTree(), visitor, docCount); + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(score(), scoreMode, iterator); + } + + @Override + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } + return cost; + } + }; + } } } From bb5dffb1ca67ca314f50b84ff599a17879f5dba6 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 31 Jul 2025 23:49:00 +0000 Subject: [PATCH 20/38] correct state saving logic in intersectLeft Signed-off-by: 
Sawan Srivastava --- .../ApproximatePointRangeQuery.java | 131 ++++++++---------- .../search/approximate/ResumableDISI.java | 13 +- 2 files changed, 65 insertions(+), 79 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 7e1d3fb45e0ad..7a38bd9196482 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -30,7 +30,6 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.DocIdSetBuilder; -import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IntsRef; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.NumericPointEncoder; @@ -241,8 +240,12 @@ private boolean checkValidPointValues(PointValues values) throws IOException { return true; } - private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount, ResumableDISI.BKDState state) - throws IOException { + private void intersectLeft( + PointValues.PointTree pointTree, + PointValues.IntersectVisitor visitor, + long[] docCount, + ResumableDISI.BKDState state + ) throws IOException { intersectLeft(visitor, pointTree, docCount, state); assert pointTree.moveToParent() == false; } @@ -254,8 +257,12 @@ private void intersectRight(PointValues.PointTree pointTree, PointValues.Interse } // custom intersect visitor to walk the left of the tree - public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState bkdState) - throws IOException { + public void intersectLeft( + PointValues.IntersectVisitor visitor, + PointValues.PointTree pointTree, + long[] docCount, + ResumableDISI.BKDState bkdState + ) throws 
IOException { if (docCount[0] >= size) { // Save current position for resumption if (bkdState != null) { @@ -264,7 +271,7 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin } return; } - PointValues.Relation r = visitor.compare(bkdState.getCurrentTree().getMinPackedValue(), bkdState.getCurrentTree().getMaxPackedValue()); + PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { return; } @@ -278,18 +285,6 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin pointTree.visitDocValues(visitor); } -// After visiting docs, check if we hit the limit -// if (docCount[0] >= size && bkdState != null) { -// // We've processed this leaf and hit the limit -// // Find the next unvisited position by moving up the tree -// PointValues.PointTree nextPosition = findNextUnvisitedPosition(pointTree); -// if (nextPosition != null) { -// bkdState.setCurrentTree(nextPosition); -// bkdState.setInProgress(true); -// } else { -// bkdState.setExhausted(true); -// } -// } return; } // For CELL_INSIDE_QUERY, check if we can skip right child @@ -313,6 +308,13 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin } // Process both children: left first, then right if needed intersectLeft(visitor, pointTree, docCount, bkdState); + if (docCount[0] >= size && !bkdState.hasSetTree()) { + if (bkdState != null) { + bkdState.setCurrentTree(pointTree); + bkdState.setInProgress(true); + bkdState.setHasSetTree(true); + } + } if (docCount[0] < size && rightChild != null) { intersectLeft(visitor, rightChild, docCount, bkdState); } @@ -396,72 +398,49 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { - DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final 
PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); - long cost = -1; - - // Create per-shard BKD state to avoid concurrency issues - ResumableDISI.BKDState state = new ResumableDISI.BKDState(); + DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + long cost = -1; - @Override - public Scorer get(long leadCost) throws IOException { - // Check if we have a saved tree and we're not exhausted - if (state.getCurrentTree() == null && !state.isExhausted()) { - // First call - start from the root - state.setCurrentTree(values.getPointTree()); - docCount[0] = 0; // Reset doc count for first call - } else if (state.getCurrentTree() != null && !state.isExhausted()) { - result = new DocIdSetBuilder(reader.maxDoc(), values); - - // Resume from where we left off - docCount[0] = (int) state.getDocCount(); + // Create per-shard BKD state to avoid concurrency issues + ResumableDISI.BKDState state = new ResumableDISI.BKDState(); - } + @Override + public Scorer get(long leadCost) throws IOException { + // Check if we have a saved tree and we're not exhausted + if (state.getCurrentTree() == null && !state.isExhausted()) { + // First call - start from the root + state.setCurrentTree(values.getPointTree()); + docCount[0] = 0; // Reset doc count for first call + } else if (state.getCurrentTree() != null && !state.isExhausted()) { + result = new DocIdSetBuilder(reader.maxDoc(), values); + + state.setHasSetTree(false); + // Resume from where we left off + docCount[0] = (int) state.getDocCount(); - // Only process if we haven't collected enough documents and we're not exhausted - if (!state.isExhausted() && docCount[0] < size && state.getCurrentTree() != null) { - // Reset the in-progress flag before processing - state.setInProgress(false); + } - // Call intersect with the current tree, passing the shard state - intersectLeft(state.getCurrentTree(), visitor, 
docCount, state); + // Only process if we haven't collected enough documents and we're not exhausted + if (!state.isExhausted() && docCount[0] < size && state.getCurrentTree() != null) { + // Reset the in-progress flag before processing + state.setInProgress(false); - // Update the state's docCount - state.setDocCount(docCount[0]); + // Call intersect with the current tree, passing the shard state + intersectLeft(state.getCurrentTree(), visitor, docCount, state); - // If we didn't collect enough documents and we're not in progress, we've exhausted the tree -// if (docCount[0] < size && !state.isInProgress()) { -// state.setExhausted(true); -// } - } + // Update the state's docCount + state.setDocCount(docCount[0]); -// if (state.isExhausted()){ -// return new Scorer() { -// @Override -// public int docID() { -// return 0; -// } -// -// @Override -// public DocIdSetIterator iterator() { -// return null; -// } -// -// @Override -// public float getMaxScore(int upTo) throws IOException { -// return 0; -// } -// -// @Override -// public float score() throws IOException { -// return 0; -// } -// }; -// } - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(score(), scoreMode, iterator); - } + // If we didn't collect enough documents and we're not in progress, we've exhausted the tree + if (docCount[0] < size && !state.isInProgress()) { + state.setExhausted(true); + } + } + DocIdSetIterator iterator = result.build().iterator(); + return new ConstantScoreScorer(score(), scoreMode, iterator); + } @Override public long cost() { diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index 3fda22516271d..bc791b1435803 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -34,8 +34,6 @@ public class ResumableDISI extends 
DocIdSetIterator { private int currentDocId = -1; private boolean fullyExhausted = false; - // BKD traversal state for approximatable queries - private BKDState bkdState; private int documentsScored = 0; // Total documents scored across all expansions /** @@ -56,7 +54,6 @@ public ResumableDISI(ScorerSupplier scorerSupplier) { public ResumableDISI(ScorerSupplier scorerSupplier, int expansionSize) { this.scorerSupplier = scorerSupplier; this.expansionSize = expansionSize; - this.bkdState = new BKDState(); } @Override @@ -215,6 +212,7 @@ public static class BKDState { private boolean isExhausted = false; private long docCount = 0; private boolean inProgress = false; + private boolean hasSetTree = false; public PointValues.PointTree getCurrentTree() { return currentTree; @@ -251,5 +249,14 @@ public boolean isInProgress() { public void setInProgress(boolean inProgress) { this.inProgress = inProgress; } + + public boolean hasSetTree() { + return hasSetTree; + } + + public void setHasSetTree(boolean hasSet) { + this.hasSetTree = hasSetTree; + } + } } From 29a60906b03fe37602b661eda604aef01474bc4d Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 7 Aug 2025 23:15:34 +0000 Subject: [PATCH 21/38] before iterative tree traversal for resumability Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 154 +++++++------- .../ApproximatePointRangeQuery.java | 188 +++++++++++++----- .../search/approximate/ResumableDISI.java | 47 ++++- 3 files changed, 260 insertions(+), 129 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 3817d8bb4d0d9..7ead898d4526d 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -114,6 +114,32 @@ public 
int docID() { } }; } + /** + * Get an estimate of the {@link Scorer} that would be returned by {@link #get}. + */ + @Override + public long cost() { + if (cost == -1) { + // Estimate cost as the minimum cost of all clauses (conjunction) + if (!clauseWeights.isEmpty()) { + cost = Long.MAX_VALUE; + for (Weight weight : clauseWeights) { + try { + ScorerSupplier supplier = weight.scorerSupplier(context); + if (supplier != null) { + cost = Math.min(cost, supplier.cost()); + } + } catch (IOException e) { + // If we can't get the cost, use a default + cost = Math.min(cost, 1000); + } + } + } else { + cost = 0; + } + } + return cost; + } /** * Get a scorer that is optimized for bulk-scoring. @@ -126,114 +152,96 @@ public BulkScorer bulkScorer() throws IOException { // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries List clauseIterators = new ArrayList<>(clauseWeights.size()); + System.out.println("DEBUG: Creating iterators for " + clauseWeights.size() + " clauses"); + for (Weight weight : clauseWeights) { Query query = weight.getQuery(); ScorerSupplier supplier = weight.scorerSupplier(context); + System.out.println("DEBUG: Processing query: " + query.getClass().getSimpleName() + " - " + query); if (query instanceof ApproximateQuery) { // Use ResumableDISI for approximatable queries + System.out.println("DEBUG: Using ResumableDISI for ApproximateQuery"); ResumableDISI disi = new ResumableDISI(supplier); clauseIterators.add(disi); } else { // Use regular DocIdSetIterator for non-approximatable queries + System.out.println("DEBUG: Using regular DISI for non-approximatable query"); Scorer scorer = supplier.get(supplier.cost()); - clauseIterators.add(scorer.iterator()); + DocIdSetIterator iterator = scorer.iterator(); + System.out.println("DEBUG: Regular iterator cost: " + iterator.cost()); + clauseIterators.add(iterator); } } - // Use Lucene's ConjunctionUtils to create the conjunction - DocIdSetIterator conjunctionDISI = 
ConjunctionUtils.intersectIterators(clauseIterators); + // Use Lucene's ConjunctionUtils to create the conjunction ONCE (outside the BulkScorer) + final DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); + // Create a simple scorer for the collector + Scorer scorer = new Scorer() { + @Override + public DocIdSetIterator iterator() { + return conjunctionDISI; + } - // Create a simple bulk scorer that wraps the conjunction - return new BulkScorer() { @Override - public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { - // Create a simple scorer for the collector - Scorer scorer = new Scorer() { - @Override - public DocIdSetIterator iterator() { - return conjunctionDISI; - } + public float score() throws IOException { + return 0.0f; + } - @Override - public float score() throws IOException { - return boost; - } + @Override + public float getMaxScore(int upTo) throws IOException { + return 0.0f; + } - @Override - public float getMaxScore(int upTo) throws IOException { - return boost; - } + @Override + public int docID() { + return conjunctionDISI.docID(); + } + }; - @Override - public int docID() { - return conjunctionDISI.docID(); - } - }; + // Create a simple bulk scorer that wraps the conjunction + return new BulkScorer() { + @Override + public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { collector.setScorer(scorer); - // Track how many documents we've collected - int collected = 0; - int docID; - - // Continue collecting until we reach the threshold - while (collected < threshold) { - // Get the next document from the conjunction - docID = conjunctionDISI.nextDoc(); - - if (docID == DocIdSetIterator.NO_MORE_DOCS) { - // No more documents - ResumableDISIs will expand internally if possible - break; + // Position the iterator correctly (following Lucene's DefaultBulkScorer pattern) + if (conjunctionDISI.docID() < min) { + if 
(conjunctionDISI.docID() == min - 1) { + conjunctionDISI.nextDoc(); } - - if (docID >= max) { - // We've reached the end of the range - return docID; + else { + conjunctionDISI.advance(min); } - - if (docID >= min && (acceptDocs == null || acceptDocs.get(docID))) { - // Collect the document - collector.collect(docID); + } + int collected = 0; + int doc = -1; + + // Score documents in the range [min, max) following Lucene's pattern + // Note: No threshold limit here - that's handled by individual ResumableDISI clauses + for (doc = conjunctionDISI.docID(); doc < max; doc = conjunctionDISI.nextDoc()) { + if (acceptDocs == null || acceptDocs.get(doc)) { + System.out.println("Conjunction Hit: "+doc); + collector.collect(doc); collected++; } } - // We've either collected enough documents or exhausted all possibilities - return DocIdSetIterator.NO_MORE_DOCS; + // Return the current iterator position (standard Lucene pattern) + return conjunctionDISI.docID(); } + @Override public long cost() { return ApproximateBooleanScorerSupplier.this.cost(); } }; - } - /** - * Get an estimate of the {@link Scorer} that would be returned by {@link #get}. 
- */ - @Override - public long cost() { - if (cost == -1) { - // Estimate cost as the minimum cost of all clauses (conjunction) - if (!clauseWeights.isEmpty()) { - cost = Long.MAX_VALUE; - for (Weight weight : clauseWeights) { - try { - ScorerSupplier supplier = weight.scorerSupplier(context); - if (supplier != null) { - cost = Math.min(cost, supplier.cost()); - } - } catch (IOException e) { - // If we can't get the cost, use a default - cost = Math.min(cost, 1000); - } - } - } else { - cost = 0; - } - } - return cost; } } + + + + diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 7a38bd9196482..2329305d7276c 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -171,11 +171,19 @@ public void grow(int count) { adder = result.grow(count); } + @Override public void visit(int docID) { // it is possible that size < 1024 and docCount < size but we will continue to count through all the 1024 docs + if (docCount[0] == 0) { + System.out.println(docID); + } adder.add(docID); docCount[0]++; + if (docCount[0] == 10241) { + System.out.println(docID); +// System.out.println(docCount[0]); + } } @Override @@ -185,8 +193,19 @@ public void visit(DocIdSetIterator iterator) throws IOException { @Override public void visit(IntsRef ref) { + if (docCount[0] == 0){ + System.out.println(ref.ints[0]); + } + if (docCount[0] == 10240) { + System.out.println(ref.ints[0]); +// System.out.println(docCount[0]); + } adder.add(ref); docCount[0] += ref.length; + if (docCount[0] == 10240) { + System.out.println(ref.ints[0]); +// System.out.println(docCount[0]); + } } @Override @@ -246,7 +265,8 @@ private void intersectLeft( long[] docCount, ResumableDISI.BKDState state ) throws IOException { - intersectLeft(visitor, 
pointTree, docCount, state); + intersectLeft(visitor, pointTree, docCount, state, true); // Top-level call + // Only assert for complete traversals (top-level calls) assert pointTree.moveToParent() == false; } @@ -261,18 +281,25 @@ public void intersectLeft( PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, - ResumableDISI.BKDState bkdState + ResumableDISI.BKDState bkdState, + boolean isTopLevel ) throws IOException { if (docCount[0] >= size) { // Save current position for resumption if (bkdState != null) { - bkdState.setCurrentTree(pointTree); + bkdState.needMore = true; + bkdState.setPointTree(pointTree); bkdState.setInProgress(true); } return; } PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { +// if (bkdState != null) { +// bkdState.needMore = true; +// bkdState.setCurrentTree(pointTree); +// bkdState.setInProgress(true); +// } return; } @@ -285,6 +312,16 @@ public void intersectLeft( pointTree.visitDocValues(visitor); } + // After processing leaf, check if we need to save state for resumption + if (docCount[0] >= size) { + // We have enough documents - save state and signal need for more + if (bkdState != null) { + bkdState.setPointTree(pointTree); + bkdState.setInProgress(true); + bkdState.needMore = true; + } + } + return; } // For CELL_INSIDE_QUERY, check if we can skip right child @@ -294,8 +331,18 @@ public void intersectLeft( if (leftSize >= needed) { // Process only left child - intersectLeft(visitor, pointTree, docCount, bkdState); - pointTree.moveToParent(); + intersectLeft(visitor, pointTree, docCount, bkdState, isTopLevel); // Pass through isTopLevel + if (docCount[0] >= size) { + // We have enough documents - save state and signal need for more + if (bkdState != null) { + bkdState.setPointTree(pointTree); + bkdState.setInProgress(true); + bkdState.needMore = true; + } + } + if (isTopLevel) { + 
pointTree.moveToParent(); + } return; } } @@ -307,35 +354,22 @@ public void intersectLeft( pointTree.moveToChild(); } // Process both children: left first, then right if needed - intersectLeft(visitor, pointTree, docCount, bkdState); - if (docCount[0] >= size && !bkdState.hasSetTree()) { - if (bkdState != null) { - bkdState.setCurrentTree(pointTree); - bkdState.setInProgress(true); - bkdState.setHasSetTree(true); - } - } + intersectLeft(visitor, pointTree, docCount, bkdState, isTopLevel); // Pass through isTopLevel + if (docCount[0] >= size) { + // We have enough documents - save state and signal need for more + if (bkdState != null) { + bkdState.setPointTree(pointTree); + bkdState.setInProgress(true); + bkdState.needMore = true; + } + } if (docCount[0] < size && rightChild != null) { - intersectLeft(visitor, rightChild, docCount, bkdState); + intersectLeft(visitor, rightChild, docCount, bkdState, isTopLevel); // Pass through isTopLevel } - pointTree.moveToParent(); + // Only call moveToParent() for top-level calls + pointTree.moveToParent(); } - // Helper method to find the next unvisited position after processing a leaf - private PointValues.PointTree findNextUnvisitedPosition(PointValues.PointTree currentLeaf) throws IOException { - PointValues.PointTree tree = currentLeaf.clone(); - - // Move up the tree to find the next unvisited sibling or ancestor's sibling - if (tree.moveToParent()) { - // Try to move to sibling (next unvisited subtree) - if (tree.moveToSibling()) { - return tree.clone(); - } - } - - // No more unvisited positions found - return null; - } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) @@ -398,47 +432,99 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { + 
ResumableDISI.BKDState state = new ResumableDISI.BKDState(); + // Keep a visitor for cost estimation only DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); long cost = -1; - // Create per-shard BKD state to avoid concurrency issues - ResumableDISI.BKDState state = new ResumableDISI.BKDState(); @Override public Scorer get(long leadCost) throws IOException { + System.out.println("DEBUG: ApproximatePointRangeQuery.get() called - BKD state management disabled"); + +// // Create fresh DocIdSetBuilder and visitor for each call + result = new DocIdSetBuilder(reader.maxDoc(), values); + visitor = getIntersectVisitor(result, docCount); + + // Simple approach: always traverse from root + System.out.println("DEBUG: Starting BKD traversal from root, target size: " + size); +// values.intersect(visitor); + // Check if we have a saved tree and we're not exhausted - if (state.getCurrentTree() == null && !state.isExhausted()) { + if (state.getPointTree() == null && !state.isExhausted()) { // First call - start from the root - state.setCurrentTree(values.getPointTree()); + System.out.println("DEBUG: First call - starting from root"); + state.setPointTree(values.getPointTree()); docCount[0] = 0; // Reset doc count for first call - } else if (state.getCurrentTree() != null && !state.isExhausted()) { - result = new DocIdSetBuilder(reader.maxDoc(), values); - - state.setHasSetTree(false); - // Resume from where we left off - docCount[0] = (int) state.getDocCount(); - + } else if (state.getPointTree() != null && !state.isExhausted()) { + // Resume from where we left off - but reset docCount for this expansion + System.out.println("DEBUG: Resuming from saved state"); + docCount[0] = 0; + } else { + System.out.println("DEBUG: BKD state is exhausted or invalid"); } - // Only process if we haven't collected 
enough documents and we're not exhausted - if (!state.isExhausted() && docCount[0] < size && state.getCurrentTree() != null) { + // Only process if we're not exhausted and have a valid tree + if (!state.isExhausted() && state.getPointTree() != null) { + System.out.println("DEBUG: Processing BKD tree, current docCount: " + docCount[0] + ", size: " + size); // Reset the in-progress flag before processing state.setInProgress(false); // Call intersect with the current tree, passing the shard state - intersectLeft(state.getCurrentTree(), visitor, docCount, state); + System.out.println(context.isTopLevel); + intersectLeft(visitor, state.getPointTree(), docCount, state, context.isTopLevel); // Resumable call + System.out.println("DEBUG: After intersectLeft, docCount: " + docCount[0] + ", inProgress: " + state.isInProgress()); + if (docCount[0] == 0 && !state.isInProgress()) { + System.out.println("DEBUG: Setting BKD state to exhausted because no documents found and not inProgress"); + state.setExhausted(true); + } - // Update the state's docCount - state.setDocCount(docCount[0]); + System.out.println("After intersect left, can tree move to parent?"+ state.getPointTree().moveToParent()); - // If we didn't collect enough documents and we're not in progress, we've exhausted the tree - if (docCount[0] < size && !state.isInProgress()) { - state.setExhausted(true); + if (!context.isTopLevel) { + if (!state.getPointTree().moveToSibling()) { + // No more siblings - try to move up and find next unvisited subtree + state.getPointTree().moveToParent(); + state.getPointTree().moveToSibling(); + } } - } + state.needMore = false; + while (docCount[0] < size && !state.needMore) { + // Reset needMore for this iteration + state.needMore = false; + + + if (!state.getPointTree().moveToSibling()) { + // No more siblings - try to move up and find next unvisited subtree + if (!state.getPointTree().moveToParent()) { + // Reached root, no more nodes to process + break; + } + // + if 
(!state.getPointTree().moveToSibling()) { + // No sibling at parent level either, we're done + break; + } + // + intersectLeft(visitor, state.getPointTree(), docCount, state, true); // Resumable call + } + // + intersectLeft(visitor, state.getPointTree(), docCount, state, false); // Resumable call + } + System.out.println("DEBUG: BKD traversal completed, found " + docCount[0] + " documents"); + + docCount[0] = 0; + + // If we didn't collect any documents and we're not in progress, we've exhausted the tree + + } else { + System.out.println("DEBUG: Skipping BKD processing - exhausted: " + state.isExhausted()); + } DocIdSetIterator iterator = result.build().iterator(); + result = null; + System.out.println("DEBUG: Built iterator with cost: " + iterator.cost()); return new ConstantScoreScorer(score(), scoreMode, iterator); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index bc791b1435803..f326478e910d8 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -35,6 +35,11 @@ public class ResumableDISI extends DocIdSetIterator { private boolean fullyExhausted = false; private int documentsScored = 0; // Total documents scored across all expansions + private int documentsReturned = 0; // Count of documents returned by nextDoc() + + // Debug: Add a unique ID to distinguish between ResumableDISI instances + private static int instanceCounter = 0; + private final int instanceId; /** * Creates a new ResumableDISI with the default expansion size of 10,000 documents. 
@@ -54,6 +59,8 @@ public ResumableDISI(ScorerSupplier scorerSupplier) { public ResumableDISI(ScorerSupplier scorerSupplier, int expansionSize) { this.scorerSupplier = scorerSupplier; this.expansionSize = expansionSize; + this.instanceId = ++instanceCounter; + System.out.println("DEBUG: Created ResumableDISI instance " + instanceId); } @Override @@ -73,6 +80,7 @@ public int nextDoc() throws IOException { return NO_MORE_DOCS; } // expandInternally() already positioned us on the first document + documentsReturned++; return currentDocId; } @@ -81,39 +89,50 @@ public int nextDoc() throws IOException { if (doc != NO_MORE_DOCS) { currentDocId = doc; + documentsReturned++; return doc; } // Current iterator exhausted, try to expand internally if (expandInternally()) { // expandInternally() already positioned us on the first document of the new batch + documentsReturned++; return currentDocId; } // No more expansion possible +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - EXHAUSTED after returning " + documentsReturned + " total documents"); currentDocId = NO_MORE_DOCS; return NO_MORE_DOCS; } @Override public int advance(int target) throws IOException { +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance(" + target + ") called, currentDocId: " + currentDocId); + if (fullyExhausted) { +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() called but already exhausted"); return NO_MORE_DOCS; } // If we don't have a current iterator, get one if (currentDisi == null) { if (!expandInternally()) { +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expandInternally failed"); return NO_MORE_DOCS; } // If the first document is >= target, we're good if (currentDocId >= target) { +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() first doc " + currentDocId + " >= target " + target); + // Don't increment documentsReturned - it was already counted in expandInternally() return 
currentDocId; } // Otherwise, advance to target int doc = currentDisi.advance(target); if (doc != NO_MORE_DOCS) { currentDocId = doc; + // Don't increment documentsReturned here either - advance() skips documents, doesn't return them one by one +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() found doc " + doc + " >= target " + target); return doc; } // Fall through to try expansion @@ -122,26 +141,34 @@ public int advance(int target) throws IOException { int doc = currentDisi.advance(target); if (doc != NO_MORE_DOCS) { currentDocId = doc; + // Don't increment documentsReturned - advance() skips documents, doesn't return them one by one +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() existing iterator found doc " + doc + " >= target " + target); return doc; } // Current iterator exhausted, try to expand +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() current iterator exhausted, trying to expand"); } // Current iterator exhausted, try to expand internally if (expandInternally()) { // If the first document of new batch is >= target, we're good if (currentDocId >= target) { + // Don't increment documentsReturned - it was already counted in expandInternally() +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expanded, first doc " + currentDocId + " >= target " + target); return currentDocId; } // Otherwise, advance to target int doc = currentDisi.advance(target); if (doc != NO_MORE_DOCS) { currentDocId = doc; + // Don't increment documentsReturned - advance() skips documents +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expanded and found doc " + doc + " >= target " + target); return doc; } } // No more expansion possible + System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() EXHAUSTED after returning " + documentsReturned + " total documents"); currentDocId = NO_MORE_DOCS; fullyExhausted = true; return 
NO_MORE_DOCS; @@ -159,6 +186,13 @@ private boolean expandInternally() throws IOException { return false; } +// // For now, disable expansion after first call to test basic logic +// if (currentDisi != null) { +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - expansion disabled, but NOT marking as exhausted"); +// // Don't set fullyExhausted = true here! Let the iterator continue with its current batch +// return false; +// } + // Get a new scorer from the supplier - this will resume from saved BKD state Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); if (scorer == null) { @@ -169,6 +203,8 @@ private boolean expandInternally() throws IOException { currentDisi = scorer.iterator(); documentsScored += expansionSize; // Track total documents scored + System.out.println("DEBUG: ResumableDISI " + instanceId + " - got iterator with " + currentDisi.cost() + " documents"); + // Check if the new iterator has any documents int firstDoc = currentDisi.nextDoc(); if (firstDoc == NO_MORE_DOCS) { @@ -183,7 +219,7 @@ private boolean expandInternally() throws IOException { @Override public long cost() { - return scorerSupplier.cost(); + return 10_000L; } /** @@ -213,14 +249,15 @@ public static class BKDState { private long docCount = 0; private boolean inProgress = false; private boolean hasSetTree = false; + public boolean needMore = false; - public PointValues.PointTree getCurrentTree() { + public PointValues.PointTree getPointTree() { return currentTree; } - public void setCurrentTree(PointValues.PointTree tree) { + public void setPointTree(PointValues.PointTree tree) { if (tree != null) { - this.currentTree = tree.clone(); + this.currentTree = tree; } else { this.currentTree = null; } @@ -255,7 +292,7 @@ public boolean hasSetTree() { } public void setHasSetTree(boolean hasSet) { - this.hasSetTree = hasSetTree; + this.hasSetTree = hasSet; } } From 69c5b4c2bb83e308e9cd75f05d80675e5a8ae064 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 8 Aug 
2025 00:17:41 +0000 Subject: [PATCH 22/38] iterative bkd traversal Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 3 +- .../ApproximatePointRangeQuery.java | 113 ++++++++++++------ 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 7ead898d4526d..b7d3b01778b08 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -222,12 +222,13 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // Note: No threshold limit here - that's handled by individual ResumableDISI clauses for (doc = conjunctionDISI.docID(); doc < max; doc = conjunctionDISI.nextDoc()) { if (acceptDocs == null || acceptDocs.get(doc)) { - System.out.println("Conjunction Hit: "+doc); collector.collect(doc); collected++; } } + System.out.println("Num conjunction hits "+collected); + // Return the current iterator position (standard Lucene pattern) return conjunctionDISI.docID(); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 2329305d7276c..ccddd1b138478 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -370,6 +370,69 @@ public void intersectLeft( pointTree.moveToParent(); } + public void intersectLeftIterative(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState state) throws IOException { + + while (true) { + System.out.println("Doc count: 
"+docCount[0]); + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } + + PointValues.Relation compare = + visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + if (compare == PointValues.Relation.CELL_INSIDE_QUERY) { + // This cell is fully inside the query shape: recursively add all points in this cell + // without filtering + pointTree.visitDocIDs(visitor); + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } + } else if (compare == PointValues.Relation.CELL_CROSSES_QUERY) { + // The cell crosses the shape boundary, or the cell fully contains the query, so we fall + // through and do full filtering: + if (pointTree.moveToChild()) { + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } + continue; + } + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } + // TODO: we can assert that the first value here in fact matches what the pointTree + // claimed? 
+ // Leaf node; scan and filter all points in this block: + pointTree.visitDocValues(visitor); + } + // position ourself to next place + while (pointTree.moveToSibling() == false) { + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } + if (pointTree.moveToParent() == false) { + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } + return; + } + } + + + + } + } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) @@ -471,48 +534,20 @@ public Scorer get(long leadCost) throws IOException { // Reset the in-progress flag before processing state.setInProgress(false); - // Call intersect with the current tree, passing the shard state - System.out.println(context.isTopLevel); - intersectLeft(visitor, state.getPointTree(), docCount, state, context.isTopLevel); // Resumable call - System.out.println("DEBUG: After intersectLeft, docCount: " + docCount[0] + ", inProgress: " + state.isInProgress()); - if (docCount[0] == 0 && !state.isInProgress()) { - System.out.println("DEBUG: Setting BKD state to exhausted because no documents found and not inProgress"); + // Use intersectLeftIterative for resumable traversal + System.out.println("DEBUG: Starting intersectLeftIterative, current docCount: " + docCount[0]); + intersectLeftIterative(visitor, state.getPointTree(), docCount, state); + System.out.println("DEBUG: After intersectLeftIterative, docCount: " + docCount[0] + ", size: " + size); + + // Check if we collected enough documents + if (docCount[0] >= size) { + state.setInProgress(true); + state.needMore = true; + } else { + // If we didn't reach the size limit, we've exhausted the tree state.setExhausted(true); } - System.out.println("After intersect left, can tree move to parent?"+ state.getPointTree().moveToParent()); - - if 
(!context.isTopLevel) { - if (!state.getPointTree().moveToSibling()) { - // No more siblings - try to move up and find next unvisited subtree - state.getPointTree().moveToParent(); - state.getPointTree().moveToSibling(); - } - } - - state.needMore = false; - while (docCount[0] < size && !state.needMore) { - // Reset needMore for this iteration - state.needMore = false; - - - if (!state.getPointTree().moveToSibling()) { - // No more siblings - try to move up and find next unvisited subtree - if (!state.getPointTree().moveToParent()) { - // Reached root, no more nodes to process - break; - } - // - if (!state.getPointTree().moveToSibling()) { - // No sibling at parent level either, we're done - break; - } - // - intersectLeft(visitor, state.getPointTree(), docCount, state, true); // Resumable call - } - // - intersectLeft(visitor, state.getPointTree(), docCount, state, false); // Resumable call - } System.out.println("DEBUG: BKD traversal completed, found " + docCount[0] + " documents"); docCount[0] = 0; From 51389fd8e371acdbd0113de334d245fa9f54db8c Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 9 Aug 2025 04:04:17 +0000 Subject: [PATCH 23/38] more updates Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 31 ++--- .../ApproximateBooleanScorerSupplier.java | 27 ++-- .../ApproximatePointRangeQuery.java | 118 +++++++++--------- .../search/approximate/ResumableDISI.java | 97 ++++++++------ 4 files changed, 154 insertions(+), 119 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index 631a93a9f2131..716b196d60e13 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -79,21 +79,22 @@ protected boolean canApproximate(SearchContext context) { return 
false; } - // For single clause boolean queries, check if the clause can be approximated - if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { - BooleanClause singleClause = clauses.get(0); - Query clauseQuery = singleClause.query(); - - // If the clause is already an ApproximateScoreQuery, we can approximate + set context - if (clauseQuery instanceof ApproximateScoreQuery approximateScoreQuery) { - if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { - return nestedBool.canApproximate(context); - } - return approximateScoreQuery.getApproximationQuery().canApproximate(context); - } - } - - return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); +// // For single clause boolean queries, check if the clause can be approximated +// if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { +// BooleanClause singleClause = clauses.get(0); +// Query clauseQuery = singleClause.query(); +// +// // If the clause is already an ApproximateScoreQuery, we can approximate + set context +// if (clauseQuery instanceof ApproximateScoreQuery approximateScoreQuery) { +// if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { +// return nestedBool.canApproximate(context); +// } +// return approximateScoreQuery.getApproximationQuery().canApproximate(context); +// } +// } + +// return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); + return clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index b7d3b01778b08..a2a52fac0ad1c 100644 --- 
a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -152,24 +152,24 @@ public BulkScorer bulkScorer() throws IOException { // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries List clauseIterators = new ArrayList<>(clauseWeights.size()); - System.out.println("DEBUG: Creating iterators for " + clauseWeights.size() + " clauses"); +// System.out.println("DEBUG: Creating iterators for " + clauseWeights.size() + " clauses"); for (Weight weight : clauseWeights) { Query query = weight.getQuery(); ScorerSupplier supplier = weight.scorerSupplier(context); - System.out.println("DEBUG: Processing query: " + query.getClass().getSimpleName() + " - " + query); +// System.out.println("DEBUG: Processing query: " + query.getClass().getSimpleName() + " - " + query); if (query instanceof ApproximateQuery) { // Use ResumableDISI for approximatable queries - System.out.println("DEBUG: Using ResumableDISI for ApproximateQuery"); +// System.out.println("DEBUG: Using ResumableDISI for ApproximateQuery"); ResumableDISI disi = new ResumableDISI(supplier); clauseIterators.add(disi); } else { // Use regular DocIdSetIterator for non-approximatable queries - System.out.println("DEBUG: Using regular DISI for non-approximatable query"); +// System.out.println("DEBUG: Using regular DISI for non-approximatable query"); Scorer scorer = supplier.get(supplier.cost()); DocIdSetIterator iterator = scorer.iterator(); - System.out.println("DEBUG: Regular iterator cost: " + iterator.cost()); +// System.out.println("DEBUG: Regular iterator cost: " + iterator.cost()); clauseIterators.add(iterator); } } @@ -201,6 +201,8 @@ public int docID() { // Create a simple bulk scorer that wraps the conjunction return new BulkScorer() { + private int totalCollected = 0; // Track total hits across all score() calls + @Override public 
int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { @@ -218,18 +220,27 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr int collected = 0; int doc = -1; - // Score documents in the range [min, max) following Lucene's pattern - // Note: No threshold limit here - that's handled by individual ResumableDISI clauses + // Score documents in the range [min, max) with early termination for (doc = conjunctionDISI.docID(); doc < max; doc = conjunctionDISI.nextDoc()) { + // Early termination when we reach the threshold +// if (totalCollected >= threshold) { +//// System.out.println("DEBUG: Early termination at " + totalCollected + " hits (threshold: " + threshold + ")"); +// break; +// } + if (acceptDocs == null || acceptDocs.get(doc)) { collector.collect(doc); collected++; + totalCollected++; } } + + System.out.println("Total Collected: " + totalCollected + " Collected this window: " + collected); - System.out.println("Num conjunction hits "+collected); +// System.out.println("Num conjunction hits " + collected + " (total: " + totalCollected + ")"); // Return the current iterator position (standard Lucene pattern) + System.out.println("Conjunction DISI current position after bulkscorer.score: "+conjunctionDISI.docID()); return conjunctionDISI.docID(); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index ccddd1b138478..c9a98c75d6747 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -46,6 +46,10 @@ * after {@code size} is hit */ public class ApproximatePointRangeQuery extends ApproximateQuery { + // Track total documents across all get() calls + private static int totalDocsAdded = 0; + private static int totalGetCalls = 
0; + public static final Function LONG_FORMAT = bytes -> Long.toString(LongPoint.decodeDimension(bytes, 0)); public static final Function INT_FORMAT = bytes -> Integer.toString(IntPoint.decodeDimension(bytes, 0)); public static final Function HALF_FLOAT_FORMAT = bytes -> Float.toString(HalfFloatPoint.decodeDimension(bytes, 0)); @@ -174,15 +178,15 @@ public void grow(int count) { @Override public void visit(int docID) { - // it is possible that size < 1024 and docCount < size but we will continue to count through all the 1024 docs + // Log first docID if (docCount[0] == 0) { - System.out.println(docID); + System.out.println("First docID: " + docID); } adder.add(docID); docCount[0]++; - if (docCount[0] == 10241) { - System.out.println(docID); -// System.out.println(docCount[0]); + // Log when we hit certain milestones + if (docCount[0] >= 10240) { + System.out.println("Last docID at 10240: " + docID); } } @@ -193,18 +197,15 @@ public void visit(DocIdSetIterator iterator) throws IOException { @Override public void visit(IntsRef ref) { + // Log first docID from bulk visit if (docCount[0] == 0){ - System.out.println(ref.ints[0]); - } - if (docCount[0] == 10240) { - System.out.println(ref.ints[0]); -// System.out.println(docCount[0]); + System.out.println("First docID (bulk): " + ref.ints[0]); } adder.add(ref); docCount[0] += ref.length; - if (docCount[0] == 10240) { - System.out.println(ref.ints[0]); -// System.out.println(docCount[0]); + // Log last docID from bulk visit when we hit milestone + if (docCount[0] >= 10240) { + System.out.println("Last docID (bulk) at " + docCount[0] + ": " + ref.ints[ref.length - 1]); } } @@ -373,64 +374,55 @@ public void intersectLeft( public void intersectLeftIterative(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState state) throws IOException { while (true) { - System.out.println("Doc count: "+docCount[0]); - if (docCount[0] >= size) { - state.setPointTree(pointTree); - 
state.setInProgress(true); - return; - } - PointValues.Relation compare = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (compare == PointValues.Relation.CELL_INSIDE_QUERY) { - // This cell is fully inside the query shape: recursively add all points in this cell - // without filtering - pointTree.visitDocIDs(visitor); - if (docCount[0] >= size) { - state.setPointTree(pointTree); - state.setInProgress(true); - return; + // Check if processing this entire subtree would exceed our limit + long subtreeSize = pointTree.size(); + if (docCount[0] + subtreeSize > size) { + // Too big - need to process children individually + if (pointTree.moveToChild()) { + continue; // Process children one by one + } } + // Safe to process entire subtree + pointTree.visitDocIDs(visitor); +// if (docCount[0] >= size) { +// System.out.println("DEBUG: Saving state at node - min: " + java.util.Arrays.toString(pointTree.getMinPackedValue()) + +// ", max: " + java.util.Arrays.toString(pointTree.getMaxPackedValue())); +// state.setPointTree(pointTree); +// state.setInProgress(true); +// return; +// } } else if (compare == PointValues.Relation.CELL_CROSSES_QUERY) { // The cell crosses the shape boundary, or the cell fully contains the query, so we fall // through and do full filtering: if (pointTree.moveToChild()) { - if (docCount[0] >= size) { - state.setPointTree(pointTree); - state.setInProgress(true); - return; - } continue; } - if (docCount[0] >= size) { - state.setPointTree(pointTree); - state.setInProgress(true); - return; - } - // TODO: we can assert that the first value here in fact matches what the pointTree - // claimed? 
// Leaf node; scan and filter all points in this block: pointTree.visitDocValues(visitor); +// if (docCount[0] >= size) { +// System.out.println("DEBUG: Saving state at leaf node - min: " + java.util.Arrays.toString(pointTree.getMinPackedValue()) + +// ", max: " + java.util.Arrays.toString(pointTree.getMaxPackedValue())); +// state.setPointTree(pointTree); +// state.setInProgress(true); +// return; +// } } // position ourself to next place while (pointTree.moveToSibling() == false) { - if (docCount[0] >= size) { - state.setPointTree(pointTree); - state.setInProgress(true); - return; - } if (pointTree.moveToParent() == false) { - if (docCount[0] >= size) { - state.setPointTree(pointTree); - state.setInProgress(true); - return; - } + // Reached true root - entire BKD tree traversal is complete + state.setExhausted(true); return; } } - - - + if (docCount[0] >= size) { + state.setPointTree(pointTree); + state.setInProgress(true); + return; + } } } @@ -521,9 +513,9 @@ public Scorer get(long leadCost) throws IOException { state.setPointTree(values.getPointTree()); docCount[0] = 0; // Reset doc count for first call } else if (state.getPointTree() != null && !state.isExhausted()) { - // Resume from where we left off - but reset docCount for this expansion + // Resume from where we left off - reset docCount for this expansion System.out.println("DEBUG: Resuming from saved state"); - docCount[0] = 0; + docCount[0] = 0; // Reset doc count for each expansion } else { System.out.println("DEBUG: BKD state is exhausted or invalid"); } @@ -536,21 +528,27 @@ public Scorer get(long leadCost) throws IOException { // Use intersectLeftIterative for resumable traversal System.out.println("DEBUG: Starting intersectLeftIterative, current docCount: " + docCount[0]); - intersectLeftIterative(visitor, state.getPointTree(), docCount, state); + if (!context.isTopLevel) { + intersectLeftIterative(visitor, state.getPointTree(), docCount, state); + } else { + intersectLeft(visitor, 
state.getPointTree(), docCount, state, true); + } System.out.println("DEBUG: After intersectLeftIterative, docCount: " + docCount[0] + ", size: " + size); // Check if we collected enough documents if (docCount[0] >= size) { state.setInProgress(true); state.needMore = true; - } else { - // If we didn't reach the size limit, we've exhausted the tree - state.setExhausted(true); } + // Note: exhaustion is now handled inside intersectLeftIterative + + //System.out.println("DEBUG: BKD traversal completed, found " + docCount[0] + " documents"); - System.out.println("DEBUG: BKD traversal completed, found " + docCount[0] + " documents"); + // Track total documents added + totalDocsAdded += docCount[0]; + totalGetCalls++; + System.out.println("DEBUG: Total docs added across all calls: " + totalDocsAdded + " (call #" + totalGetCalls + ")"); - docCount[0] = 0; // If we didn't collect any documents and we're not in progress, we've exhausted the tree diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index f326478e910d8..5b480bde4b8cc 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -60,7 +60,7 @@ public ResumableDISI(ScorerSupplier scorerSupplier, int expansionSize) { this.scorerSupplier = scorerSupplier; this.expansionSize = expansionSize; this.instanceId = ++instanceCounter; - System.out.println("DEBUG: Created ResumableDISI instance " + instanceId); +// System.out.println("DEBUG: Created ResumableDISI instance " + instanceId); } @Override @@ -79,9 +79,18 @@ public int nextDoc() throws IOException { if (!expandInternally()) { return NO_MORE_DOCS; } - // expandInternally() already positioned us on the first document - documentsReturned++; - return currentDocId; + // Position the new iterator on its first document + int doc = currentDisi.nextDoc(); 
+ if (doc != NO_MORE_DOCS) { + currentDocId = doc; + documentsReturned++; + return doc; + } else { + // Iterator was empty after all + fullyExhausted = true; + currentDocId = NO_MORE_DOCS; + return NO_MORE_DOCS; + } } // Try to get the next document from current iterator @@ -95,14 +104,19 @@ public int nextDoc() throws IOException { // Current iterator exhausted, try to expand internally if (expandInternally()) { - // expandInternally() already positioned us on the first document of the new batch - documentsReturned++; - return currentDocId; + // Position the new iterator on its first document + doc = currentDisi.nextDoc(); + if (doc != NO_MORE_DOCS) { + currentDocId = doc; + documentsReturned++; + return doc; + } } // No more expansion possible // System.out.println("DEBUG: ResumableDISI " + instanceId + " - EXHAUSTED after returning " + documentsReturned + " total documents"); currentDocId = NO_MORE_DOCS; + fullyExhausted = true; return NO_MORE_DOCS; } @@ -115,25 +129,33 @@ public int advance(int target) throws IOException { return NO_MORE_DOCS; } + // If target is NO_MORE_DOCS, no point in expanding + if (target == NO_MORE_DOCS) { +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() target is NO_MORE_DOCS, marking as exhausted"); + fullyExhausted = true; + currentDocId = NO_MORE_DOCS; + return NO_MORE_DOCS; + } + // If we don't have a current iterator, get one if (currentDisi == null) { if (!expandInternally()) { // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expandInternally failed"); return NO_MORE_DOCS; } - // If the first document is >= target, we're good - if (currentDocId >= target) { -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() first doc " + currentDocId + " >= target " + target); - // Don't increment documentsReturned - it was already counted in expandInternally() - return currentDocId; - } - // Otherwise, advance to target - int doc = currentDisi.advance(target); + // 
Position the new iterator and check if it meets target + int doc = currentDisi.nextDoc(); if (doc != NO_MORE_DOCS) { currentDocId = doc; - // Don't increment documentsReturned here either - advance() skips documents, doesn't return them one by one -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() found doc " + doc + " >= target " + target); - return doc; + if (currentDocId >= target) { + return currentDocId; + } + // Otherwise, advance to target + doc = currentDisi.advance(target); + if (doc != NO_MORE_DOCS) { + currentDocId = doc; + return doc; + } } // Fall through to try expansion } else { @@ -150,21 +172,23 @@ public int advance(int target) throws IOException { } // Current iterator exhausted, try to expand internally - if (expandInternally()) { - // If the first document of new batch is >= target, we're good - if (currentDocId >= target) { - // Don't increment documentsReturned - it was already counted in expandInternally() -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expanded, first doc " + currentDocId + " >= target " + target); - return currentDocId; - } - // Otherwise, advance to target - int doc = currentDisi.advance(target); + while (expandInternally()) { + // Position the new iterator and check if it meets target + int doc = currentDisi.nextDoc(); if (doc != NO_MORE_DOCS) { currentDocId = doc; - // Don't increment documentsReturned - advance() skips documents -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expanded and found doc " + doc + " >= target " + target); - return doc; + if (currentDocId >= target) { + return currentDocId; + } + // Otherwise, advance to target + doc = currentDisi.advance(target); + if (doc != NO_MORE_DOCS) { + currentDocId = doc; + return doc; + } } + // This expansion didn't have a suitable document, try expanding again +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expansion didn't find target " + target + ", 
trying next expansion"); } // No more expansion possible @@ -194,6 +218,7 @@ private boolean expandInternally() throws IOException { // } // Get a new scorer from the supplier - this will resume from saved BKD state + System.out.println("DEBUG: ResumableDISI " + instanceId + " - calling expandInternally"); Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); if (scorer == null) { fullyExhausted = true; @@ -203,17 +228,17 @@ private boolean expandInternally() throws IOException { currentDisi = scorer.iterator(); documentsScored += expansionSize; // Track total documents scored - System.out.println("DEBUG: ResumableDISI " + instanceId + " - got iterator with " + currentDisi.cost() + " documents"); +// System.out.println("DEBUG: ResumableDISI " + instanceId + " - got iterator with " + currentDisi.cost() + " documents"); - // Check if the new iterator has any documents - int firstDoc = currentDisi.nextDoc(); - if (firstDoc == NO_MORE_DOCS) { + // Check if the iterator has any documents by looking at cost + if (currentDisi.cost() == 0) { + System.out.println("DEBUG: ResumableDISI " + instanceId + " - expandInternally got empty iterator (cost=0), marking as exhausted"); fullyExhausted = true; + currentDocId = NO_MORE_DOCS; return false; } - // Position the iterator on the first document - currentDocId = firstDoc; + // Don't position the iterator - let nextDoc() or advance() handle that return true; } From 58c139dcfe17ad342ca07f55da4689f83b6a54e5 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Sat, 9 Aug 2025 15:55:59 +0000 Subject: [PATCH 24/38] Added early termination in bulk scorer Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 22 ++++++++++++++----- .../ApproximatePointRangeQuery.java | 22 +++++++++---------- .../search/approximate/ResumableDISI.java | 10 ++++++++- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java 
b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index a2a52fac0ad1c..01fe8969a5c3f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -201,7 +201,8 @@ public int docID() { // Create a simple bulk scorer that wraps the conjunction return new BulkScorer() { - private int totalCollected = 0; // Track total hits across all score() calls + private int totalCollected = 0; + private boolean expansionStopped = false; // Track total hits across all score() calls @Override public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { @@ -223,10 +224,19 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // Score documents in the range [min, max) with early termination for (doc = conjunctionDISI.docID(); doc < max; doc = conjunctionDISI.nextDoc()) { // Early termination when we reach the threshold -// if (totalCollected >= threshold) { -//// System.out.println("DEBUG: Early termination at " + totalCollected + " hits (threshold: " + threshold + ")"); -// break; -// } + if (totalCollected >= 10000) { + if (!expansionStopped) { + // Stop all ResumableDISI instances from expanding further + for (DocIdSetIterator iter : clauseIterators) { + if (iter instanceof ResumableDISI disi) { + disi.stopExpansion(); + } + } + expansionStopped = true; + System.out.println("DEBUG: Stopped expansion for all ResumableDISI at " + totalCollected + " hits"); + } + return DocIdSetIterator.NO_MORE_DOCS; // Exit the entire score method + } if (acceptDocs == null || acceptDocs.get(doc)) { collector.collect(doc); @@ -234,7 +244,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr totalCollected++; } } - + System.out.println("Total Collected: " + totalCollected + " Collected this window: " + collected); 
// System.out.println("Num conjunction hits " + collected + " (total: " + totalCollected + ")"); diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index c9a98c75d6747..2a703b75dd060 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -180,13 +180,13 @@ public void grow(int count) { public void visit(int docID) { // Log first docID if (docCount[0] == 0) { - System.out.println("First docID: " + docID); +// System.out.println("First docID: " + docID); } adder.add(docID); docCount[0]++; // Log when we hit certain milestones if (docCount[0] >= 10240) { - System.out.println("Last docID at 10240: " + docID); +// System.out.println("Last docID at 10240: " + docID); } } @@ -199,13 +199,13 @@ public void visit(DocIdSetIterator iterator) throws IOException { public void visit(IntsRef ref) { // Log first docID from bulk visit if (docCount[0] == 0){ - System.out.println("First docID (bulk): " + ref.ints[0]); +// System.out.println("First docID (bulk): " + ref.ints[0]); } adder.add(ref); docCount[0] += ref.length; // Log last docID from bulk visit when we hit milestone if (docCount[0] >= 10240) { - System.out.println("Last docID (bulk) at " + docCount[0] + ": " + ref.ints[ref.length - 1]); +// System.out.println("Last docID (bulk) at " + docCount[0] + ": " + ref.ints[ref.length - 1]); } } @@ -378,13 +378,13 @@ public void intersectLeftIterative(PointValues.IntersectVisitor visitor, PointVa visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (compare == PointValues.Relation.CELL_INSIDE_QUERY) { // Check if processing this entire subtree would exceed our limit - long subtreeSize = pointTree.size(); - if (docCount[0] + subtreeSize > size) { - // Too big - need to process 
children individually - if (pointTree.moveToChild()) { - continue; // Process children one by one - } - } +// long subtreeSize = pointTree.size(); +// if (docCount[0] + subtreeSize > size) { +// // Too big - need to process children individually +// if (pointTree.moveToChild()) { +// continue; // Process children one by one +// } +// } // Safe to process entire subtree pointTree.visitDocIDs(visitor); // if (docCount[0] >= size) { diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index 5b480bde4b8cc..05c4ba3a5a78b 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -33,6 +33,7 @@ public class ResumableDISI extends DocIdSetIterator { private DocIdSetIterator currentDisi; private int currentDocId = -1; private boolean fullyExhausted = false; + private volatile boolean stopExpansion = false; private int documentsScored = 0; // Total documents scored across all expansions private int documentsReturned = 0; // Count of documents returned by nextDoc() @@ -206,7 +207,7 @@ public int advance(int target) throws IOException { * @throws IOException If there's an error getting the scorer */ private boolean expandInternally() throws IOException { - if (fullyExhausted) { + if (fullyExhausted || stopExpansion) { return false; } @@ -247,6 +248,13 @@ public long cost() { return 10_000L; } + /** + * Signal this iterator to stop expanding and return NO_MORE_DOCS + */ + public void stopExpansion() { + stopExpansion = true; + } + /** * Returns whether this iterator has been fully exhausted. 
* From b2bf37800ba36edb519cac2e9bc3632a5baf9fa0 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Tue, 12 Aug 2025 15:30:11 +0000 Subject: [PATCH 25/38] added new iterative approach + debugging Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 30 +-- .../ApproximateBooleanScorerSupplier.java | 82 +++++-- .../ApproximatePointRangeQuery.java | 227 ++++++++++-------- .../search/approximate/ResumableDISI.java | 74 +++--- 4 files changed, 253 insertions(+), 160 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index 716b196d60e13..e7441e628944e 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -79,21 +79,21 @@ protected boolean canApproximate(SearchContext context) { return false; } -// // For single clause boolean queries, check if the clause can be approximated -// if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { -// BooleanClause singleClause = clauses.get(0); -// Query clauseQuery = singleClause.query(); -// -// // If the clause is already an ApproximateScoreQuery, we can approximate + set context -// if (clauseQuery instanceof ApproximateScoreQuery approximateScoreQuery) { -// if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { -// return nestedBool.canApproximate(context); -// } -// return approximateScoreQuery.getApproximationQuery().canApproximate(context); -// } -// } - -// return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); + // // For single clause boolean queries, check if the clause can be approximated + // if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { + // 
BooleanClause singleClause = clauses.get(0); + // Query clauseQuery = singleClause.query(); + // + // // If the clause is already an ApproximateScoreQuery, we can approximate + set context + // if (clauseQuery instanceof ApproximateScoreQuery approximateScoreQuery) { + // if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { + // return nestedBool.canApproximate(context); + // } + // return approximateScoreQuery.getApproximationQuery().canApproximate(context); + // } + // } + + // return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); return clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 01fe8969a5c3f..1ef213509e7c3 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -88,6 +88,27 @@ public Scorer get(long leadCost) throws IOException { } } + // Debug: Print first 100 docIDs from each DISI + System.out.println("DEBUG: Printing first 100 docIDs from each DISI:"); + for (int i = 0; i < clauseIterators.size(); i++) { + DocIdSetIterator iter = clauseIterators.get(i); + System.out.print("DISI " + i + " first 100 docIDs: ["); + try { + for (int j = 0; j < 100; j++) { + int docId = iter.nextDoc(); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + System.out.print("NO_MORE_DOCS"); + break; + } + System.out.print(docId); + if (j < 99) System.out.print(", "); + } + System.out.println("]"); + } catch (IOException e) { + System.out.println("Error reading DISI " + i + ": " + e.getMessage() + "]"); + } + } + // Use Lucene's ConjunctionUtils to create the conjunction DocIdSetIterator 
conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); @@ -100,12 +121,12 @@ public DocIdSetIterator iterator() { @Override public float score() throws IOException { - return boost; + return 0.0f; } @Override public float getMaxScore(int upTo) throws IOException { - return boost; + return 0.0f; } @Override @@ -114,6 +135,7 @@ public int docID() { } }; } + /** * Get an estimate of the {@link Scorer} that would be returned by {@link #get}. */ @@ -152,28 +174,49 @@ public BulkScorer bulkScorer() throws IOException { // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries List clauseIterators = new ArrayList<>(clauseWeights.size()); -// System.out.println("DEBUG: Creating iterators for " + clauseWeights.size() + " clauses"); + // System.out.println("DEBUG: Creating iterators for " + clauseWeights.size() + " clauses"); for (Weight weight : clauseWeights) { Query query = weight.getQuery(); ScorerSupplier supplier = weight.scorerSupplier(context); -// System.out.println("DEBUG: Processing query: " + query.getClass().getSimpleName() + " - " + query); + // System.out.println("DEBUG: Processing query: " + query.getClass().getSimpleName() + " - " + query); if (query instanceof ApproximateQuery) { // Use ResumableDISI for approximatable queries -// System.out.println("DEBUG: Using ResumableDISI for ApproximateQuery"); + // System.out.println("DEBUG: Using ResumableDISI for ApproximateQuery"); ResumableDISI disi = new ResumableDISI(supplier); clauseIterators.add(disi); } else { // Use regular DocIdSetIterator for non-approximatable queries -// System.out.println("DEBUG: Using regular DISI for non-approximatable query"); + // System.out.println("DEBUG: Using regular DISI for non-approximatable query"); Scorer scorer = supplier.get(supplier.cost()); DocIdSetIterator iterator = scorer.iterator(); -// System.out.println("DEBUG: Regular iterator cost: " + iterator.cost()); + // System.out.println("DEBUG: Regular iterator cost: 
" + iterator.cost()); clauseIterators.add(iterator); } } + // // Debug: Print first 100 docIDs from each DISI + System.out.println("DEBUG: Printing first 100 docIDs from each DISI:"); + for (int i = 0; i < clauseIterators.size(); i++) { + DocIdSetIterator iter = clauseIterators.get(i); + System.out.print("DISI " + i + " first 100 docIDs: ["); + try { + for (int j = 0; j < 10000; j++) { + int docId = iter.nextDoc(); + if (docId == DocIdSetIterator.NO_MORE_DOCS) { + System.out.print("NO_MORE_DOCS"); + break; + } + System.out.print(docId); + if (j < 9999) System.out.print(", "); + } + System.out.println("]"); + } catch (IOException e) { + System.out.println("Error reading DISI " + i + ": " + e.getMessage() + "]"); + } + } + // Use Lucene's ConjunctionUtils to create the conjunction ONCE (outside the BulkScorer) final DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); // Create a simple scorer for the collector @@ -202,19 +245,20 @@ public int docID() { // Create a simple bulk scorer that wraps the conjunction return new BulkScorer() { private int totalCollected = 0; - private boolean expansionStopped = false; // Track total hits across all score() calls + private boolean expansionStopped = false; + private final List conjunctionDocIds = new ArrayList<>(); // Track total hits across all score() calls @Override public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { + System.out.println("bulkscorer.score called with min: " + min + " and max: " + max); collector.setScorer(scorer); // Position the iterator correctly (following Lucene's DefaultBulkScorer pattern) if (conjunctionDISI.docID() < min) { if (conjunctionDISI.docID() == min - 1) { conjunctionDISI.nextDoc(); - } - else { + } else { conjunctionDISI.advance(min); } } @@ -234,6 +278,8 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr } expansionStopped = true; System.out.println("DEBUG: Stopped expansion 
for all ResumableDISI at " + totalCollected + " hits"); + System.out.println("DEBUG: Conjunction docIDs: " + conjunctionDocIds); + } return DocIdSetIterator.NO_MORE_DOCS; // Exit the entire score method } @@ -242,19 +288,25 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr collector.collect(doc); collected++; totalCollected++; + conjunctionDocIds.add(doc); } } System.out.println("Total Collected: " + totalCollected + " Collected this window: " + collected); -// System.out.println("Num conjunction hits " + collected + " (total: " + totalCollected + ")"); + // Check if conjunction exhausted + if (conjunctionDISI.docID() == DocIdSetIterator.NO_MORE_DOCS) { + System.out.println("DEBUG: Conjunction exhausted at " + totalCollected + " total hits"); + System.out.println("DEBUG: Conjunction docIDs: " + conjunctionDocIds); + } + + // System.out.println("Num conjunction hits " + collected + " (total: " + totalCollected + ")"); // Return the current iterator position (standard Lucene pattern) - System.out.println("Conjunction DISI current position after bulkscorer.score: "+conjunctionDISI.docID()); + System.out.println("Conjunction DISI current position after bulkscorer.score: " + conjunctionDISI.docID()); return conjunctionDISI.docID(); } - @Override public long cost() { return ApproximateBooleanScorerSupplier.this.cost(); @@ -263,7 +315,3 @@ public long cost() { } } - - - - diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 2a703b75dd060..699209309e48c 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -38,7 +38,10 @@ import org.opensearch.search.sort.SortOrder; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import 
java.util.Objects; +import java.util.Stack; import java.util.function.Function; /** @@ -50,6 +53,10 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { private static int totalDocsAdded = 0; private static int totalGetCalls = 0; + // Store first 10k docIDs for validation + private static final List firstDocIds = new ArrayList<>(); + private static boolean docIdsCollected = false; + public static final Function LONG_FORMAT = bytes -> Long.toString(LongPoint.decodeDimension(bytes, 0)); public static final Function INT_FORMAT = bytes -> Integer.toString(IntPoint.decodeDimension(bytes, 0)); public static final Function HALF_FLOAT_FORMAT = bytes -> Float.toString(HalfFloatPoint.decodeDimension(bytes, 0)); @@ -175,18 +182,19 @@ public void grow(int count) { adder = result.grow(count); } - @Override public void visit(int docID) { // Log first docID if (docCount[0] == 0) { -// System.out.println("First docID: " + docID); + System.out.println("First docID: " + docID); } + // firstDocIds.add(docID); + adder.add(docID); docCount[0]++; // Log when we hit certain milestones - if (docCount[0] >= 10240) { -// System.out.println("Last docID at 10240: " + docID); + if (docCount[0] >= 10200) { + System.out.println("Last docID at 10240: " + docID); } } @@ -198,14 +206,21 @@ public void visit(DocIdSetIterator iterator) throws IOException { @Override public void visit(IntsRef ref) { // Log first docID from bulk visit - if (docCount[0] == 0){ -// System.out.println("First docID (bulk): " + ref.ints[0]); + if (docCount[0] == 0) { + System.out.println("First docID (bulk): " + ref.ints[0]); } + // + // // Collect first 10240 docIDs for validation + + // for (int i = 0; i < ref.length; i++) { + // firstDocIds.add(ref.ints[i]); + // } + adder.add(ref); docCount[0] += ref.length; // Log last docID from bulk visit when we hit milestone if (docCount[0] >= 10240) { -// System.out.println("Last docID (bulk) at " + docCount[0] + ": " + ref.ints[ref.length - 1]); + // 
System.out.println("Last docID (bulk) at " + docCount[0] + ": " + ref.ints[ref.length - 1]); } } @@ -260,13 +275,9 @@ private boolean checkValidPointValues(PointValues values) throws IOException { return true; } - private void intersectLeft( - PointValues.PointTree pointTree, - PointValues.IntersectVisitor visitor, - long[] docCount, - ResumableDISI.BKDState state - ) throws IOException { - intersectLeft(visitor, pointTree, docCount, state, true); // Top-level call + private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, long[] docCount) + throws IOException { + intersectLeft(visitor, pointTree, docCount); // Top-level call // Only assert for complete traversals (top-level calls) assert pointTree.moveToParent() == false; } @@ -278,32 +289,15 @@ private void intersectRight(PointValues.PointTree pointTree, PointValues.Interse } // custom intersect visitor to walk the left of the tree - public void intersectLeft( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - long[] docCount, - ResumableDISI.BKDState bkdState, - boolean isTopLevel - ) throws IOException { + public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) + throws IOException { if (docCount[0] >= size) { - // Save current position for resumption - if (bkdState != null) { - bkdState.needMore = true; - bkdState.setPointTree(pointTree); - bkdState.setInProgress(true); - } return; } PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { -// if (bkdState != null) { -// bkdState.needMore = true; -// bkdState.setCurrentTree(pointTree); -// bkdState.setInProgress(true); -// } return; } - // Handle leaf nodes if (pointTree.moveToChild() == false) { if (r == PointValues.Relation.CELL_INSIDE_QUERY) { @@ -312,17 +306,6 @@ public void intersectLeft( // CELL_CROSSES_QUERY 
pointTree.visitDocValues(visitor); } - - // After processing leaf, check if we need to save state for resumption - if (docCount[0] >= size) { - // We have enough documents - save state and signal need for more - if (bkdState != null) { - bkdState.setPointTree(pointTree); - bkdState.setInProgress(true); - bkdState.needMore = true; - } - } - return; } // For CELL_INSIDE_QUERY, check if we can skip right child @@ -332,18 +315,8 @@ public void intersectLeft( if (leftSize >= needed) { // Process only left child - intersectLeft(visitor, pointTree, docCount, bkdState, isTopLevel); // Pass through isTopLevel - if (docCount[0] >= size) { - // We have enough documents - save state and signal need for more - if (bkdState != null) { - bkdState.setPointTree(pointTree); - bkdState.setInProgress(true); - bkdState.needMore = true; - } - } - if (isTopLevel) { - pointTree.moveToParent(); - } + intersectLeft(visitor, pointTree, docCount); + pointTree.moveToParent(); return; } } @@ -355,45 +328,33 @@ public void intersectLeft( pointTree.moveToChild(); } // Process both children: left first, then right if needed - intersectLeft(visitor, pointTree, docCount, bkdState, isTopLevel); // Pass through isTopLevel - if (docCount[0] >= size) { - // We have enough documents - save state and signal need for more - if (bkdState != null) { - bkdState.setPointTree(pointTree); - bkdState.setInProgress(true); - bkdState.needMore = true; - } - } + intersectLeft(visitor, pointTree, docCount); if (docCount[0] < size && rightChild != null) { - intersectLeft(visitor, rightChild, docCount, bkdState, isTopLevel); // Pass through isTopLevel + intersectLeft(visitor, rightChild, docCount); } - // Only call moveToParent() for top-level calls - pointTree.moveToParent(); + pointTree.moveToParent(); } - public void intersectLeftIterative(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, ResumableDISI.BKDState state) throws IOException { + public void intersectLeftIterative( + 
PointValues.IntersectVisitor visitor, + PointValues.PointTree pointTree, + long[] docCount, + ResumableDISI.BKDState state + ) throws IOException { while (true) { - PointValues.Relation compare = - visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + PointValues.Relation compare = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); if (compare == PointValues.Relation.CELL_INSIDE_QUERY) { // Check if processing this entire subtree would exceed our limit -// long subtreeSize = pointTree.size(); -// if (docCount[0] + subtreeSize > size) { -// // Too big - need to process children individually -// if (pointTree.moveToChild()) { -// continue; // Process children one by one -// } -// } + long subtreeSize = pointTree.size(); + if (docCount[0] + subtreeSize > size) { + // Too big - need to process children individually + if (pointTree.moveToChild()) { + continue; // Process children one by one + } + } // Safe to process entire subtree pointTree.visitDocIDs(visitor); -// if (docCount[0] >= size) { -// System.out.println("DEBUG: Saving state at node - min: " + java.util.Arrays.toString(pointTree.getMinPackedValue()) + -// ", max: " + java.util.Arrays.toString(pointTree.getMaxPackedValue())); -// state.setPointTree(pointTree); -// state.setInProgress(true); -// return; -// } } else if (compare == PointValues.Relation.CELL_CROSSES_QUERY) { // The cell crosses the shape boundary, or the cell fully contains the query, so we fall // through and do full filtering: @@ -402,13 +363,6 @@ public void intersectLeftIterative(PointValues.IntersectVisitor visitor, PointVa } // Leaf node; scan and filter all points in this block: pointTree.visitDocValues(visitor); -// if (docCount[0] >= size) { -// System.out.println("DEBUG: Saving state at leaf node - min: " + java.util.Arrays.toString(pointTree.getMinPackedValue()) + -// ", max: " + java.util.Arrays.toString(pointTree.getMaxPackedValue())); -// state.setPointTree(pointTree); -// 
state.setInProgress(true); -// return; -// } } // position ourself to next place while (pointTree.moveToSibling() == false) { @@ -426,6 +380,75 @@ public void intersectLeftIterative(PointValues.IntersectVisitor visitor, PointVa } } + public void intersectLeftIterativeNew( + PointValues.IntersectVisitor visitor, + PointValues.PointTree pointTree, + long[] docCount, + ResumableDISI.BKDState state + ) throws IOException { + + // Stack to track nodes to process + Stack nodeStack = new Stack<>(); + nodeStack.push(pointTree.clone()); + + while (!nodeStack.isEmpty() && docCount[0] < size) { + PointValues.PointTree currentTree = nodeStack.pop(); + + if (docCount[0] >= size) { + continue; + } + + PointValues.Relation r = visitor.compare(currentTree.getMinPackedValue(), currentTree.getMaxPackedValue()); + if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { + continue; + } + + // Handle leaf nodes + if (currentTree.moveToChild() == false) { + if (r == PointValues.Relation.CELL_INSIDE_QUERY) { + currentTree.visitDocIDs(visitor); + } else { + // CELL_CROSSES_QUERY + currentTree.visitDocValues(visitor); + } + continue; + } + + // Internal node processing + PointValues.PointTree leftChild = currentTree.clone(); + PointValues.PointTree rightChild = null; + + // Check if right sibling exists + if (currentTree.moveToSibling()) { + rightChild = currentTree.clone(); + } + + // For CELL_INSIDE_QUERY, check if we can skip right child + if (r == PointValues.Relation.CELL_INSIDE_QUERY && rightChild != null) { + long leftSize = leftChild.size(); + long needed = size - docCount[0]; + + if (leftSize >= needed) { + // Process only left child + nodeStack.push(leftChild); + continue; + } + } + + // Process both children: push right first (so left is processed first due to stack LIFO) + if (rightChild != null) { + nodeStack.push(rightChild); + } + nodeStack.push(leftChild); + if (docCount[0] >= size) { + state.setPointTree(currentTree); + state.setInProgress(true); + return; + } + } + + } + // 
custom intersect visitor to walk the right of tree (from rightmost leaf going left) public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) throws IOException { @@ -493,18 +516,17 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); long cost = -1; - @Override public Scorer get(long leadCost) throws IOException { System.out.println("DEBUG: ApproximatePointRangeQuery.get() called - BKD state management disabled"); -// // Create fresh DocIdSetBuilder and visitor for each call + // // Create fresh DocIdSetBuilder and visitor for each call result = new DocIdSetBuilder(reader.maxDoc(), values); visitor = getIntersectVisitor(result, docCount); // Simple approach: always traverse from root System.out.println("DEBUG: Starting BKD traversal from root, target size: " + size); -// values.intersect(visitor); + // values.intersect(visitor); // Check if we have a saved tree and we're not exhausted if (state.getPointTree() == null && !state.isExhausted()) { @@ -529,9 +551,10 @@ public Scorer get(long leadCost) throws IOException { // Use intersectLeftIterative for resumable traversal System.out.println("DEBUG: Starting intersectLeftIterative, current docCount: " + docCount[0]); if (!context.isTopLevel) { - intersectLeftIterative(visitor, state.getPointTree(), docCount, state); + intersectLeftIterativeNew(visitor, state.getPointTree(), docCount, state); } else { - intersectLeft(visitor, state.getPointTree(), docCount, state, true); + intersectLeft(visitor, state.getPointTree(), docCount); + // values.intersect(visitor); } System.out.println("DEBUG: After intersectLeftIterative, docCount: " + docCount[0] + ", size: " + size); @@ -542,13 +565,14 @@ public Scorer get(long leadCost) throws IOException { } // Note: exhaustion is now handled inside intersectLeftIterative - //System.out.println("DEBUG: BKD traversal 
completed, found " + docCount[0] + " documents"); + // System.out.println("DEBUG: BKD traversal completed, found " + docCount[0] + " documents"); // Track total documents added totalDocsAdded += docCount[0]; totalGetCalls++; - System.out.println("DEBUG: Total docs added across all calls: " + totalDocsAdded + " (call #" + totalGetCalls + ")"); - + System.out.println( + "DEBUG: Total docs added across all calls: " + totalDocsAdded + " (call #" + totalGetCalls + ")" + ); // If we didn't collect any documents and we're not in progress, we've exhausted the tree @@ -558,6 +582,7 @@ public Scorer get(long leadCost) throws IOException { DocIdSetIterator iterator = result.build().iterator(); result = null; System.out.println("DEBUG: Built iterator with cost: " + iterator.cost()); + // System.out.println("DocIDs collected: "+firstDocIds); return new ConstantScoreScorer(score(), scoreMode, iterator); } @@ -630,6 +655,8 @@ private byte[] computeEffectiveBound(SearchContext context, boolean isLowerBound @Override public boolean canApproximate(SearchContext context) { + // System.out.println("canApproximate: false"); + // return false; if (context == null) { return false; } diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java index 05c4ba3a5a78b..6265178871174 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java @@ -61,7 +61,12 @@ public ResumableDISI(ScorerSupplier scorerSupplier, int expansionSize) { this.scorerSupplier = scorerSupplier; this.expansionSize = expansionSize; this.instanceId = ++instanceCounter; -// System.out.println("DEBUG: Created ResumableDISI instance " + instanceId); + try { + expandInternally(); + } catch (IOException e) { + throw new RuntimeException(e); + } + // System.out.println("DEBUG: Created ResumableDISI instance " + 
instanceId); } @Override @@ -115,7 +120,8 @@ public int nextDoc() throws IOException { } // No more expansion possible -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - EXHAUSTED after returning " + documentsReturned + " total documents"); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - EXHAUSTED after returning " + documentsReturned + " total + // documents"); currentDocId = NO_MORE_DOCS; fullyExhausted = true; return NO_MORE_DOCS; @@ -123,16 +129,16 @@ public int nextDoc() throws IOException { @Override public int advance(int target) throws IOException { -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance(" + target + ") called, currentDocId: " + currentDocId); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance(" + target + ") called, currentDocId: " + currentDocId); if (fullyExhausted) { -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() called but already exhausted"); + System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() called but already exhausted"); return NO_MORE_DOCS; } // If target is NO_MORE_DOCS, no point in expanding if (target == NO_MORE_DOCS) { -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() target is NO_MORE_DOCS, marking as exhausted"); + System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() target is NO_MORE_DOCS, marking as exhausted"); fullyExhausted = true; currentDocId = NO_MORE_DOCS; return NO_MORE_DOCS; @@ -141,7 +147,7 @@ public int advance(int target) throws IOException { // If we don't have a current iterator, get one if (currentDisi == null) { if (!expandInternally()) { -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expandInternally failed"); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expandInternally failed"); return NO_MORE_DOCS; } // Position the new iterator and check if it meets target @@ -165,35 
+171,45 @@ public int advance(int target) throws IOException { if (doc != NO_MORE_DOCS) { currentDocId = doc; // Don't increment documentsReturned - advance() skips documents, doesn't return them one by one -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() existing iterator found doc " + doc + " >= target " + target); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() existing iterator found doc " + doc + " >= target + // " + target); return doc; } // Current iterator exhausted, try to expand -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() current iterator exhausted, trying to expand"); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() current iterator exhausted, trying to expand"); } // Current iterator exhausted, try to expand internally while (expandInternally()) { // Position the new iterator and check if it meets target - int doc = currentDisi.nextDoc(); + // Otherwise, advance to target + int doc = currentDisi.advance(target); if (doc != NO_MORE_DOCS) { currentDocId = doc; - if (currentDocId >= target) { - return currentDocId; - } - // Otherwise, advance to target - doc = currentDisi.advance(target); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - return doc; - } + return doc; } + // int doc = currentDisi.nextDoc(); + // if (doc != NO_MORE_DOCS) { + // currentDocId = doc; + // if (currentDocId >= target) { + // return currentDocId; + // } + // // Otherwise, advance to target + // doc = currentDisi.advance(target); + // if (doc != NO_MORE_DOCS) { + // currentDocId = doc; + // return doc; + // } + // } // This expansion didn't have a suitable document, try expanding again -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expansion didn't find target " + target + ", trying next expansion"); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expansion didn't find target " + target + ", trying + // next 
expansion"); } // No more expansion possible - System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() EXHAUSTED after returning " + documentsReturned + " total documents"); + System.out.println( + "DEBUG: ResumableDISI " + instanceId + " - advance() EXHAUSTED after returning " + documentsReturned + " total documents" + ); currentDocId = NO_MORE_DOCS; fullyExhausted = true; return NO_MORE_DOCS; @@ -211,12 +227,12 @@ private boolean expandInternally() throws IOException { return false; } -// // For now, disable expansion after first call to test basic logic -// if (currentDisi != null) { -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - expansion disabled, but NOT marking as exhausted"); -// // Don't set fullyExhausted = true here! Let the iterator continue with its current batch -// return false; -// } + // // For now, disable expansion after first call to test basic logic + // if (currentDisi != null) { + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - expansion disabled, but NOT marking as exhausted"); + // // Don't set fullyExhausted = true here! 
Let the iterator continue with its current batch + // return false; + // } // Get a new scorer from the supplier - this will resume from saved BKD state System.out.println("DEBUG: ResumableDISI " + instanceId + " - calling expandInternally"); @@ -229,11 +245,13 @@ private boolean expandInternally() throws IOException { currentDisi = scorer.iterator(); documentsScored += expansionSize; // Track total documents scored -// System.out.println("DEBUG: ResumableDISI " + instanceId + " - got iterator with " + currentDisi.cost() + " documents"); + // System.out.println("DEBUG: ResumableDISI " + instanceId + " - got iterator with " + currentDisi.cost() + " documents"); // Check if the iterator has any documents by looking at cost if (currentDisi.cost() == 0) { - System.out.println("DEBUG: ResumableDISI " + instanceId + " - expandInternally got empty iterator (cost=0), marking as exhausted"); + System.out.println( + "DEBUG: ResumableDISI " + instanceId + " - expandInternally got empty iterator (cost=0), marking as exhausted" + ); fullyExhausted = true; currentDocId = NO_MORE_DOCS; return false; From 5d984261becbc37d1ec0638fe9e7dbac67de6423 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Tue, 12 Aug 2025 18:03:16 +0000 Subject: [PATCH 26/38] added bulk window scoring approach Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 216 +++++++++--------- .../ApproximatePointRangeQuery.java | 110 +++------ 2 files changed, 142 insertions(+), 184 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 1ef213509e7c3..0a9a9a22d921e 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -22,6 +22,7 @@ */ public class 
ApproximateBooleanScorerSupplier extends ScorerSupplier { private final List clauseWeights; + private final List cachedSuppliers; // Cache suppliers to avoid repeated calls private final ScoreMode scoreMode; private final float boost; private final int threshold; @@ -46,16 +47,18 @@ public ApproximateBooleanScorerSupplier( LeafReaderContext context ) throws IOException { this.clauseWeights = new ArrayList<>(); + this.cachedSuppliers = new ArrayList<>(); this.scoreMode = scoreMode; this.boost = boost; this.threshold = threshold; this.context = context; - // Store weights that have valid scorer suppliers + // Store weights and cache their suppliers for (Weight clauseWeight : clauseWeights) { ScorerSupplier supplier = clauseWeight.scorerSupplier(context); if (supplier != null) { this.clauseWeights.add(clauseWeight); + this.cachedSuppliers.add(supplier); } } } @@ -73,9 +76,10 @@ public Scorer get(long leadCost) throws IOException { // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries List clauseIterators = new ArrayList<>(clauseWeights.size()); - for (Weight weight : clauseWeights) { + for (int i = 0; i < clauseWeights.size(); i++) { + Weight weight = clauseWeights.get(i); + ScorerSupplier supplier = cachedSuppliers.get(i); // Use cached supplier Query query = weight.getQuery(); - ScorerSupplier supplier = weight.scorerSupplier(context); if (query instanceof ApproximateQuery) { // Use ResumableDISI for approximatable queries @@ -88,27 +92,6 @@ public Scorer get(long leadCost) throws IOException { } } - // Debug: Print first 100 docIDs from each DISI - System.out.println("DEBUG: Printing first 100 docIDs from each DISI:"); - for (int i = 0; i < clauseIterators.size(); i++) { - DocIdSetIterator iter = clauseIterators.get(i); - System.out.print("DISI " + i + " first 100 docIDs: ["); - try { - for (int j = 0; j < 100; j++) { - int docId = iter.nextDoc(); - if (docId == DocIdSetIterator.NO_MORE_DOCS) { - 
System.out.print("NO_MORE_DOCS"); - break; - } - System.out.print(docId); - if (j < 99) System.out.print(", "); - } - System.out.println("]"); - } catch (IOException e) { - System.out.println("Error reading DISI " + i + ": " + e.getMessage() + "]"); - } - } - // Use Lucene's ConjunctionUtils to create the conjunction DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); @@ -143,18 +126,10 @@ public int docID() { public long cost() { if (cost == -1) { // Estimate cost as the minimum cost of all clauses (conjunction) - if (!clauseWeights.isEmpty()) { + if (!cachedSuppliers.isEmpty()) { cost = Long.MAX_VALUE; - for (Weight weight : clauseWeights) { - try { - ScorerSupplier supplier = weight.scorerSupplier(context); - if (supplier != null) { - cost = Math.min(cost, supplier.cost()); - } - } catch (IOException e) { - // If we can't get the cost, use a default - cost = Math.min(cost, 1000); - } + for (ScorerSupplier supplier : cachedSuppliers) { + cost = Math.min(cost, supplier.cost()); } } else { cost = 0; @@ -172,58 +147,23 @@ public BulkScorer bulkScorer() throws IOException { return null; } - // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries - List clauseIterators = new ArrayList<>(clauseWeights.size()); - // System.out.println("DEBUG: Creating iterators for " + clauseWeights.size() + " clauses"); - - for (Weight weight : clauseWeights) { - Query query = weight.getQuery(); - ScorerSupplier supplier = weight.scorerSupplier(context); - // System.out.println("DEBUG: Processing query: " + query.getClass().getSimpleName() + " - " + query); - - if (query instanceof ApproximateQuery) { - // Use ResumableDISI for approximatable queries - // System.out.println("DEBUG: Using ResumableDISI for ApproximateQuery"); - ResumableDISI disi = new ResumableDISI(supplier); - clauseIterators.add(disi); - } else { - // Use regular DocIdSetIterator for non-approximatable queries - // System.out.println("DEBUG: 
Using regular DISI for non-approximatable query"); - Scorer scorer = supplier.get(supplier.cost()); - DocIdSetIterator iterator = scorer.iterator(); - // System.out.println("DEBUG: Regular iterator cost: " + iterator.cost()); - clauseIterators.add(iterator); - } + // Calculate window size heuristic using cached suppliers + long minCost = Long.MAX_VALUE; + long maxCost = 0; + for (ScorerSupplier supplier : cachedSuppliers) { + long cost = supplier.cost(); + minCost = Math.min(minCost, cost); + maxCost = Math.max(maxCost, cost); } + final int initialWindowSize = (int) Math.min(minCost, maxCost >> 7); // max(costs)/2^7 + System.out.println("DEBUG: Window heuristic - minCost: " + minCost + ", maxCost: " + maxCost + ", initialWindowSize: " + initialWindowSize); - // // Debug: Print first 100 docIDs from each DISI - System.out.println("DEBUG: Printing first 100 docIDs from each DISI:"); - for (int i = 0; i < clauseIterators.size(); i++) { - DocIdSetIterator iter = clauseIterators.get(i); - System.out.print("DISI " + i + " first 100 docIDs: ["); - try { - for (int j = 0; j < 10000; j++) { - int docId = iter.nextDoc(); - if (docId == DocIdSetIterator.NO_MORE_DOCS) { - System.out.print("NO_MORE_DOCS"); - break; - } - System.out.print(docId); - if (j < 9999) System.out.print(", "); - } - System.out.println("]"); - } catch (IOException e) { - System.out.println("Error reading DISI " + i + ": " + e.getMessage() + "]"); - } - } - - // Use Lucene's ConjunctionUtils to create the conjunction ONCE (outside the BulkScorer) - final DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); - // Create a simple scorer for the collector + // Create a simple scorer for the collector (will be used by windowed approach) Scorer scorer = new Scorer() { @Override public DocIdSetIterator iterator() { - return conjunctionDISI; + // This won't be used in windowed approach + return DocIdSetIterator.empty(); } @Override @@ -238,7 +178,7 @@ public float 
getMaxScore(int upTo) throws IOException { @Override public int docID() { - return conjunctionDISI.docID(); + return -1; } }; @@ -248,40 +188,95 @@ public int docID() { private boolean expansionStopped = false; private final List conjunctionDocIds = new ArrayList<>(); // Track total hits across all score() calls + // Windowed approach state + private int currentWindowSize = initialWindowSize; + private DocIdSetIterator globalConjunction = null; + + private List rebuildIteratorsWithWindowSize(int windowSize) throws IOException { + List newIterators = new ArrayList<>(); + for (int i = 0; i < clauseWeights.size(); i++) { + Weight weight = clauseWeights.get(i); + ScorerSupplier supplier = cachedSuppliers.get(i); // Use cached supplier + Query query = weight.getQuery(); + + if (query instanceof ApproximateQuery) { + // For approximatable queries, try to use the window size + if (query instanceof ApproximatePointRangeQuery) { + ApproximatePointRangeQuery approxQuery = (ApproximatePointRangeQuery) query; + // Temporarily set the size + int originalSize = approxQuery.getSize(); + approxQuery.setSize(windowSize); + try { + Scorer scorer = supplier.get(windowSize); + newIterators.add(scorer.iterator()); + } finally { + // Restore original size + approxQuery.setSize(originalSize); + } + } else { + // Other approximate queries - use ResumableDISI + ResumableDISI disi = new ResumableDISI(supplier); + newIterators.add(disi); + } + } else { + // Regular queries use full cost + Scorer scorer = supplier.get(supplier.cost()); + newIterators.add(scorer.iterator()); + } + } + return newIterators; + } + @Override public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { - System.out.println("bulkscorer.score called with min: " + min + " and max: " + max); collector.setScorer(scorer); - - // Position the iterator correctly (following Lucene's DefaultBulkScorer pattern) - if (conjunctionDISI.docID() < min) { - if (conjunctionDISI.docID() == min - 
1) { - conjunctionDISI.nextDoc(); + + // Check if we need to expand window + if (totalCollected < 10000 && + (globalConjunction == null || globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS)) { + + System.out.println("DEBUG: Expanding window from " + currentWindowSize + " to " + (currentWindowSize * 3)); + currentWindowSize *= 3; + + // Rebuild iterators with new window size + List newIterators = rebuildIteratorsWithWindowSize(currentWindowSize); + globalConjunction = ConjunctionUtils.intersectIterators(newIterators); + + // Return first docID from new conjunction (could be < min) + int firstDoc = globalConjunction.nextDoc(); + if (firstDoc != DocIdSetIterator.NO_MORE_DOCS) { + System.out.println("DEBUG: New window first docID: " + firstDoc + " (min was: " + min + ")"); + return firstDoc; // CancellableBulkScorer will use this as new min + } + } + + // Score existing conjunction within [min, max) range + return scoreExistingConjunction(collector, acceptDocs, min, max); + } + + private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { + if (globalConjunction == null) { + return DocIdSetIterator.NO_MORE_DOCS; + } + + // Position the iterator correctly + if (globalConjunction.docID() < min) { + if (globalConjunction.docID() == min - 1) { + globalConjunction.nextDoc(); } else { - conjunctionDISI.advance(min); + globalConjunction.advance(min); } } + int collected = 0; int doc = -1; - // Score documents in the range [min, max) with early termination - for (doc = conjunctionDISI.docID(); doc < max; doc = conjunctionDISI.nextDoc()) { - // Early termination when we reach the threshold + // Score documents in the range [min, max) + for (doc = globalConjunction.docID(); doc < max; doc = globalConjunction.nextDoc()) { if (totalCollected >= 10000) { - if (!expansionStopped) { - // Stop all ResumableDISI instances from expanding further - for (DocIdSetIterator iter : clauseIterators) { - if (iter instanceof 
ResumableDISI disi) { - disi.stopExpansion(); - } - } - expansionStopped = true; - System.out.println("DEBUG: Stopped expansion for all ResumableDISI at " + totalCollected + " hits"); - System.out.println("DEBUG: Conjunction docIDs: " + conjunctionDocIds); - - } - return DocIdSetIterator.NO_MORE_DOCS; // Exit the entire score method + System.out.println("DEBUG: Reached 10000 hits, stopping"); + return DocIdSetIterator.NO_MORE_DOCS; // Early termination } if (acceptDocs == null || acceptDocs.get(doc)) { @@ -295,16 +290,13 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr System.out.println("Total Collected: " + totalCollected + " Collected this window: " + collected); // Check if conjunction exhausted - if (conjunctionDISI.docID() == DocIdSetIterator.NO_MORE_DOCS) { + if (globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS) { System.out.println("DEBUG: Conjunction exhausted at " + totalCollected + " total hits"); System.out.println("DEBUG: Conjunction docIDs: " + conjunctionDocIds); } - // System.out.println("Num conjunction hits " + collected + " (total: " + totalCollected + ")"); - - // Return the current iterator position (standard Lucene pattern) - System.out.println("Conjunction DISI current position after bulkscorer.score: " + conjunctionDISI.docID()); - return conjunctionDISI.docID(); + System.out.println("Conjunction DISI current position after bulkscorer.score: " + globalConjunction.docID()); + return globalConjunction.docID(); } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 699209309e48c..9817a51bec3e0 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -186,7 +186,7 @@ public void grow(int count) { public void 
visit(int docID) { // Log first docID if (docCount[0] == 0) { - System.out.println("First docID: " + docID); +// System.out.println("First docID: " + docID); } // firstDocIds.add(docID); @@ -194,7 +194,7 @@ public void visit(int docID) { docCount[0]++; // Log when we hit certain milestones if (docCount[0] >= 10200) { - System.out.println("Last docID at 10240: " + docID); +// System.out.println("Last docID at 10240: " + docID); } } @@ -207,7 +207,7 @@ public void visit(DocIdSetIterator iterator) throws IOException { public void visit(IntsRef ref) { // Log first docID from bulk visit if (docCount[0] == 0) { - System.out.println("First docID (bulk): " + ref.ints[0]); +// System.out.println("First docID (bulk): " + ref.ints[0]); } // // // Collect first 10240 docIDs for validation @@ -372,11 +372,11 @@ public void intersectLeftIterative( return; } } - if (docCount[0] >= size) { - state.setPointTree(pointTree); - state.setInProgress(true); - return; - } +// if (docCount[0] >= size) { +// state.setPointTree(pointTree); +// state.setInProgress(true); +// return; +// } } } @@ -518,72 +518,38 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti @Override public Scorer get(long leadCost) throws IOException { - System.out.println("DEBUG: ApproximatePointRangeQuery.get() called - BKD state management disabled"); - - // // Create fresh DocIdSetBuilder and visitor for each call - result = new DocIdSetBuilder(reader.maxDoc(), values); - visitor = getIntersectVisitor(result, docCount); - - // Simple approach: always traverse from root - System.out.println("DEBUG: Starting BKD traversal from root, target size: " + size); - // values.intersect(visitor); - - // Check if we have a saved tree and we're not exhausted - if (state.getPointTree() == null && !state.isExhausted()) { - // First call - start from the root - System.out.println("DEBUG: First call - starting from root"); - state.setPointTree(values.getPointTree()); - docCount[0] = 0; // Reset doc 
count for first call - } else if (state.getPointTree() != null && !state.isExhausted()) { - // Resume from where we left off - reset docCount for this expansion - System.out.println("DEBUG: Resuming from saved state"); - docCount[0] = 0; // Reset doc count for each expansion - } else { - System.out.println("DEBUG: BKD state is exhausted or invalid"); - } + // Use leadCost as dynamic size if it's reasonable, otherwise use original size + int dynamicSize = (leadCost > 0 && leadCost < Integer.MAX_VALUE) ? (int) leadCost : size; + return getWithSize(dynamicSize); + } + + public Scorer getWithSize(int dynamicSize) throws IOException { + // Temporarily update size for this call + int originalSize = size; + size = dynamicSize; - // Only process if we're not exhausted and have a valid tree - if (!state.isExhausted() && state.getPointTree() != null) { - System.out.println("DEBUG: Processing BKD tree, current docCount: " + docCount[0] + ", size: " + size); - // Reset the in-progress flag before processing - state.setInProgress(false); - - // Use intersectLeftIterative for resumable traversal - System.out.println("DEBUG: Starting intersectLeftIterative, current docCount: " + docCount[0]); - if (!context.isTopLevel) { - intersectLeftIterativeNew(visitor, state.getPointTree(), docCount, state); - } else { - intersectLeft(visitor, state.getPointTree(), docCount); - // values.intersect(visitor); - } - System.out.println("DEBUG: After intersectLeftIterative, docCount: " + docCount[0] + ", size: " + size); - - // Check if we collected enough documents - if (docCount[0] >= size) { - state.setInProgress(true); - state.needMore = true; - } - // Note: exhaustion is now handled inside intersectLeftIterative - - // System.out.println("DEBUG: BKD traversal completed, found " + docCount[0] + " documents"); - - // Track total documents added - totalDocsAdded += docCount[0]; - totalGetCalls++; - System.out.println( - "DEBUG: Total docs added across all calls: " + totalDocsAdded + " (call #" 
+ totalGetCalls + ")" - ); - - // If we didn't collect any documents and we're not in progress, we've exhausted the tree - - } else { - System.out.println("DEBUG: Skipping BKD processing - exhausted: " + state.isExhausted()); + try { + System.out.println("DEBUG: ApproximatePointRangeQuery.get() called with dynamic size: " + dynamicSize); + + // For windowed approach, create fresh iterator without ResumableDISI state + DocIdSetBuilder freshResult = new DocIdSetBuilder(reader.maxDoc(), values); + long[] freshDocCount = new long[1]; + PointValues.IntersectVisitor freshVisitor = getIntersectVisitor(freshResult, freshDocCount); + + System.out.println("DEBUG: Starting fresh BKD traversal from root, target size: " + size); + + // Always start fresh traversal from root + intersectLeft(values.getPointTree(), freshVisitor, freshDocCount); + + System.out.println("DEBUG: Fresh traversal completed, docCount: " + freshDocCount[0]); + + DocIdSetIterator iterator = freshResult.build().iterator(); + System.out.println("DEBUG: Built fresh iterator with cost: " + iterator.cost()); + return new ConstantScoreScorer(score(), scoreMode, iterator); + } finally { + // Restore original size + size = originalSize; } - DocIdSetIterator iterator = result.build().iterator(); - result = null; - System.out.println("DEBUG: Built iterator with cost: " + iterator.cost()); - // System.out.println("DocIDs collected: "+firstDocIds); - return new ConstantScoreScorer(score(), scoreMode, iterator); } @Override From 17daf2d26223120b559888237a826d00f4029574 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Tue, 12 Aug 2025 22:32:07 +0000 Subject: [PATCH 27/38] cleaned up code and fixed single clause bool Signed-off-by: Sawan Srivastava --- CHANGELOG.md | 29 ++ .../index/query/BoolQueryBuilder.java | 9 +- .../approximate/ApproximateBooleanQuery.java | 39 +- .../ApproximateBooleanScorerSupplier.java | 99 ++--- .../ApproximateConjunctionDISI.java | 142 ------- .../ApproximateConjunctionScorer.java | 78 
---- .../ApproximatePointRangeQuery.java | 171 +++------ .../approximate/ApproximateScoreQuery.java | 34 +- .../search/approximate/ResumableDISI.java | 350 ------------------ 9 files changed, 137 insertions(+), 814 deletions(-) delete mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java delete mode 100644 server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java delete mode 100644 server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 8599977452207..d620d2c79ca10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,35 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Implement GRPC ConstantScoreQuery, FuzzyQuery, MatchBoolPrefixQuery, MatchPhrasePrefix, PrefixQuery, MatchQuery ([#19854](https://github.com/opensearch-project/OpenSearch/pull/19854)) - Add async periodic flush task support for pull-based ingestion ([#19878](https://github.com/opensearch-project/OpenSearch/pull/19878)) +- Add support for Warm Indices Write Block on Flood Watermark breach ([#18375](https://github.com/opensearch-project/OpenSearch/pull/18375)) +- Add support for custom index name resolver from cluster plugin ([#18593](https://github.com/opensearch-project/OpenSearch/pull/18593)) +- Rename WorkloadGroupTestUtil to WorkloadManagementTestUtil ([#18709](https://github.com/opensearch-project/OpenSearch/pull/18709)) +- Disallow resize for Warm Index, add Parameterized ITs for close in remote store ([#18686](https://github.com/opensearch-project/OpenSearch/pull/18686)) +- Ability to run Code Coverage with Gradle and produce the jacoco reports locally ([#18509](https://github.com/opensearch-project/OpenSearch/issues/18509)) +- [Workload Management] Update logging and Javadoc, rename QueryGroup to WorkloadGroup 
([#18711](https://github.com/opensearch-project/OpenSearch/issues/18711)) +- Add NodeResourceUsageStats to ClusterInfo ([#18480](https://github.com/opensearch-project/OpenSearch/issues/18472)) +- Introduce SecureHttpTransportParameters experimental API (to complement SecureTransportParameters counterpart) ([#18572](https://github.com/opensearch-project/OpenSearch/issues/18572)) +- Create equivalents of JSM's AccessController in the java agent ([#18346](https://github.com/opensearch-project/OpenSearch/issues/18346)) +- [WLM] Add WLM mode validation for workload group CRUD requests ([#18652](https://github.com/opensearch-project/OpenSearch/issues/18652)) +- Introduced a new cluster-level API to fetch remote store metadata (segments and translogs) for each shard of an index. ([#18257](https://github.com/opensearch-project/OpenSearch/pull/18257)) +- Add last index request timestamp columns to the `_cat/indices` API. ([10766](https://github.com/opensearch-project/OpenSearch/issues/10766)) +- Introduce a new pull-based ingestion plugin for file-based indexing (for local testing) ([#18591](https://github.com/opensearch-project/OpenSearch/pull/18591)) +- Add support for search pipeline in search and msearch template ([#18564](https://github.com/opensearch-project/OpenSearch/pull/18564)) +- [Workload Management] Modify logging message in WorkloadGroupService ([#18712](https://github.com/opensearch-project/OpenSearch/pull/18712)) +- Add BooleanQuery rewrite moving constant-scoring must clauses to filter clauses ([#18510](https://github.com/opensearch-project/OpenSearch/issues/18510)) +- Add functionality for plugins to inject QueryCollectorContext during QueryPhase ([#18637](https://github.com/opensearch-project/OpenSearch/pull/18637)) +- Add support for non-timing info in profiler ([#18460](https://github.com/opensearch-project/OpenSearch/issues/18460)) +- Extend 
Approximation Framework to other numeric types ([#18530](https://github.com/opensearch-project/OpenSearch/issues/18530)) +- Add Semantic Version field type mapper and extensive unit tests([#18454](https://github.com/opensearch-project/OpenSearch/pull/18454)) +- Pass index settings to system ingest processor factories. ([#18708](https://github.com/opensearch-project/OpenSearch/pull/18708)) +- Include named queries from rescore contexts in matched_queries array ([#18697](https://github.com/opensearch-project/OpenSearch/pull/18697)) +- Add the configurable limit on rule cardinality ([#18663](https://github.com/opensearch-project/OpenSearch/pull/18663)) +- [Experimental] Start in "clusterless" mode if a clusterless ClusterPlugin is loaded ([#18479](https://github.com/opensearch-project/OpenSearch/pull/18479)) +- [Star-Tree] Add star-tree search related stats ([#18707](https://github.com/opensearch-project/OpenSearch/pull/18707)) +- Add support for plugins to profile information ([#18656](https://github.com/opensearch-project/OpenSearch/pull/18656)) +- Add support for Combined Fields query ([#18724](https://github.com/opensearch-project/OpenSearch/pull/18724)) +- Multifold Improvement in Multi-Clause Boolean Query, Window Scoring Approach ([#19046](https://github.com/opensearch-project/OpenSearch/pull/19046)) + ### Changed - Faster `terms` query creation for `keyword` field with index and docValues enabled ([#19350](https://github.com/opensearch-project/OpenSearch/pull/19350)) - Refactor to move prepareIndex and prepareDelete methods to Engine class ([#19551](https://github.com/opensearch-project/OpenSearch/pull/19551)) diff --git a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java index a9564db7b2c1d..784e23afe55ea 100644 --- 
a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java @@ -343,10 +343,11 @@ protected Query doToQuery(QueryShardContext context) throws IOException { } // TODO: Figure out why multi-clause breaks testPhrasePrefix() in HighlighterWithAnalyzersTests.java - // return ((BooleanQuery) query).clauses().size() == 1 - // ? new ApproximateScoreQuery(query, new ApproximateBooleanQuery((BooleanQuery) query)) - // : query; - return new ApproximateScoreQuery(query, new ApproximateBooleanQuery((BooleanQuery) query)); + if (query instanceof BooleanQuery boolQuery) { + return new ApproximateScoreQuery(query, new ApproximateBooleanQuery(boolQuery)); + } + + return query; } private static void addBooleanClauses( diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index e7441e628944e..ed9718e0dfef7 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -48,21 +48,6 @@ public BooleanQuery getBooleanQuery() { return boolQuery; } - public Query getClauseQuery() { - return clauses.get(0).query(); - } - - public static Query unwrap(Query unwrapBoolQuery) { - Query clauseQuery = unwrapBoolQuery instanceof ApproximateBooleanQuery - ? 
((ApproximateBooleanQuery) unwrapBoolQuery).getClauseQuery() - : ((BooleanQuery) unwrapBoolQuery).clauses().get(0).query(); - if (clauseQuery instanceof ApproximateBooleanQuery nestedBool) { - return unwrap(nestedBool); - } else { - return clauseQuery; - } - } - @Override protected boolean canApproximate(SearchContext context) { if (context == null) { @@ -79,19 +64,17 @@ protected boolean canApproximate(SearchContext context) { return false; } - // // For single clause boolean queries, check if the clause can be approximated - // if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { - // BooleanClause singleClause = clauses.get(0); - // Query clauseQuery = singleClause.query(); - // - // // If the clause is already an ApproximateScoreQuery, we can approximate + set context - // if (clauseQuery instanceof ApproximateScoreQuery approximateScoreQuery) { - // if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { - // return nestedBool.canApproximate(context); - // } - // return approximateScoreQuery.getApproximationQuery().canApproximate(context); - // } - // } + // For single clause boolean queries, check if the clause can be approximated + if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { + // If the clause is already an ApproximateScoreQuery, we can approximate + set context + if (clauses.get(0).query() instanceof ApproximateScoreQuery approximateScoreQuery) { + if (approximateScoreQuery.getApproximationQuery() instanceof ApproximateBooleanQuery nestedBool) { + return nestedBool.canApproximate(context); + } + return approximateScoreQuery.getApproximationQuery().canApproximate(context); + } + return false; + } // return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); return clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); diff --git 
a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 0a9a9a22d921e..5753bfda815c3 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -26,7 +26,6 @@ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { private final ScoreMode scoreMode; private final float boost; private final int threshold; - private final LeafReaderContext context; private long cost = -1; /** @@ -51,7 +50,6 @@ public ApproximateBooleanScorerSupplier( this.scoreMode = scoreMode; this.boost = boost; this.threshold = threshold; - this.context = context; // Store weights and cache their suppliers for (Weight clauseWeight : clauseWeights) { @@ -77,19 +75,8 @@ public Scorer get(long leadCost) throws IOException { // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries List clauseIterators = new ArrayList<>(clauseWeights.size()); for (int i = 0; i < clauseWeights.size(); i++) { - Weight weight = clauseWeights.get(i); - ScorerSupplier supplier = cachedSuppliers.get(i); // Use cached supplier - Query query = weight.getQuery(); - - if (query instanceof ApproximateQuery) { - // Use ResumableDISI for approximatable queries - ResumableDISI disi = new ResumableDISI(supplier); - clauseIterators.add(disi); - } else { - // Use regular DocIdSetIterator for non-approximatable queries - Scorer scorer = supplier.get(leadCost); - clauseIterators.add(scorer.iterator()); - } + // Use regular DocIdSetIterator for non-approximatable queries + clauseIterators.add(cachedSuppliers.get(i).get(leadCost).iterator()); } // Use Lucene's ConjunctionUtils to create the conjunction @@ -155,8 +142,7 @@ public BulkScorer bulkScorer() throws IOException { minCost = Math.min(minCost, 
cost); maxCost = Math.max(maxCost, cost); } - final int initialWindowSize = (int) Math.min(minCost, maxCost >> 7); // max(costs)/2^7 - System.out.println("DEBUG: Window heuristic - minCost: " + minCost + ", maxCost: " + maxCost + ", initialWindowSize: " + initialWindowSize); + final int initialWindowSize = Math.max((1 << 15), (int) Math.min(minCost, maxCost / (1 << 7))); // Ensure minimum 10k // Create a simple scorer for the collector (will be used by windowed approach) Scorer scorer = new Scorer() { @@ -185,8 +171,6 @@ public int docID() { // Create a simple bulk scorer that wraps the conjunction return new BulkScorer() { private int totalCollected = 0; - private boolean expansionStopped = false; - private final List conjunctionDocIds = new ArrayList<>(); // Track total hits across all score() calls // Windowed approach state private int currentWindowSize = initialWindowSize; @@ -198,25 +182,18 @@ private List rebuildIteratorsWithWindowSize(int windowSize) th Weight weight = clauseWeights.get(i); ScorerSupplier supplier = cachedSuppliers.get(i); // Use cached supplier Query query = weight.getQuery(); - - if (query instanceof ApproximateQuery) { + + if (query instanceof ApproximatePointRangeQuery approxQuery) { // For approximatable queries, try to use the window size - if (query instanceof ApproximatePointRangeQuery) { - ApproximatePointRangeQuery approxQuery = (ApproximatePointRangeQuery) query; - // Temporarily set the size - int originalSize = approxQuery.getSize(); - approxQuery.setSize(windowSize); - try { - Scorer scorer = supplier.get(windowSize); - newIterators.add(scorer.iterator()); - } finally { - // Restore original size - approxQuery.setSize(originalSize); - } - } else { - // Other approximate queries - use ResumableDISI - ResumableDISI disi = new ResumableDISI(supplier); - newIterators.add(disi); + // Temporarily set the size + int originalSize = approxQuery.getSize(); + approxQuery.setSize(windowSize); + try { + Scorer scorer = 
supplier.get(windowSize); + newIterators.add(scorer.iterator()); + } finally { + // Restore original size + approxQuery.setSize(originalSize); } } else { // Regular queries use full cost @@ -229,37 +206,34 @@ private List rebuildIteratorsWithWindowSize(int windowSize) th @Override public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { - System.out.println("bulkscorer.score called with min: " + min + " and max: " + max); collector.setScorer(scorer); - + // Check if we need to expand window - if (totalCollected < 10000 && - (globalConjunction == null || globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS)) { - - System.out.println("DEBUG: Expanding window from " + currentWindowSize + " to " + (currentWindowSize * 3)); + if (totalCollected < 10000 && (globalConjunction == null || globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS)) { currentWindowSize *= 3; - + // Rebuild iterators with new window size List newIterators = rebuildIteratorsWithWindowSize(currentWindowSize); globalConjunction = ConjunctionUtils.intersectIterators(newIterators); - + // Return first docID from new conjunction (could be < min) int firstDoc = globalConjunction.nextDoc(); if (firstDoc != DocIdSetIterator.NO_MORE_DOCS) { - System.out.println("DEBUG: New window first docID: " + firstDoc + " (min was: " + min + ")"); return firstDoc; // CancellableBulkScorer will use this as new min - } + } else {} } - + // Score existing conjunction within [min, max) range - return scoreExistingConjunction(collector, acceptDocs, min, max); + int result = scoreExistingConjunction(collector, acceptDocs, min, max); + + return result; } - + private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException { if (globalConjunction == null) { return DocIdSetIterator.NO_MORE_DOCS; } - + // Position the iterator correctly if (globalConjunction.docID() < min) { if (globalConjunction.docID() == min - 1) { @@ -275,7 
+249,6 @@ private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, i // Score documents in the range [min, max) for (doc = globalConjunction.docID(); doc < max; doc = globalConjunction.nextDoc()) { if (totalCollected >= 10000) { - System.out.println("DEBUG: Reached 10000 hits, stopping"); return DocIdSetIterator.NO_MORE_DOCS; // Early termination } @@ -283,19 +256,29 @@ private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, i collector.collect(doc); collected++; totalCollected++; - conjunctionDocIds.add(doc); } } - System.out.println("Total Collected: " + totalCollected + " Collected this window: " + collected); - // Check if conjunction exhausted if (globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS) { - System.out.println("DEBUG: Conjunction exhausted at " + totalCollected + " total hits"); - System.out.println("DEBUG: Conjunction docIDs: " + conjunctionDocIds); + + // If we need more hits, expand immediately + if (totalCollected < 10000) { + currentWindowSize *= 3; + + try { + List newIterators = rebuildIteratorsWithWindowSize(currentWindowSize); + globalConjunction = ConjunctionUtils.intersectIterators(newIterators); + + int firstDoc = globalConjunction.nextDoc(); + if (firstDoc != DocIdSetIterator.NO_MORE_DOCS) { + return firstDoc; // Return new starting point + } + } catch (IOException e) {} + } + } - System.out.println("Conjunction DISI current position after bulkscorer.score: " + globalConjunction.docID()); return globalConjunction.docID(); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java deleted file mode 100644 index a7712838e12ae..0000000000000 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionDISI.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require 
contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.approximate; - -import org.apache.lucene.search.DocIdSetIterator; - -import java.io.IOException; -import java.util.List; - -/** - * A conjunction of DocIdSetIterators with support for ResumableDISI expansion. - * Closely mirrors Lucene's ConjunctionDISI architecture with lead1, lead2, and others. - */ -public class ApproximateConjunctionDISI extends DocIdSetIterator { - - final DocIdSetIterator lead1, lead2; - final DocIdSetIterator[] others; - - private final List allIterators; - - public ApproximateConjunctionDISI(List iterators) { - if (iterators.size() < 2) { - throw new IllegalArgumentException("Cannot make a ConjunctionDISI of less than 2 iterators"); - } - - this.allIterators = iterators; - - // Follow Lucene's exact structure - this.lead1 = iterators.get(0); - this.lead2 = iterators.get(1); - this.others = iterators.subList(2, iterators.size()).toArray(new DocIdSetIterator[0]); - } - - @Override - public int docID() { - return lead1.docID(); - } - - @Override - public int nextDoc() throws IOException { - return doNext(lead1.nextDoc()); - } - - @Override - public int advance(int target) throws IOException { - return doNext(lead1.advance(target)); - } - - /** - * Core conjunction logic adapted from Lucene's ConjunctionDISI.doNext() - * with resumable expansion support. 
- */ - private int doNext(int doc) throws IOException { - advanceHead: for (;;) { - // Handle NO_MORE_DOCS with resumable expansion - if (doc == NO_MORE_DOCS) { - if (tryExpandResumableDISIs()) { - // After expansion, get the next document from lead1 - doc = lead1.nextDoc(); - if (doc == NO_MORE_DOCS) { - return NO_MORE_DOCS; // Truly exhausted - } - // Continue with the new document - } else { - return NO_MORE_DOCS; // No expansion possible - } - } - - // Find agreement between the two iterators with the lower costs - // We special case them because they do not need the - // 'other.docID() < doc' check that the 'others' iterators need - final int next2 = lead2.advance(doc); - if (next2 != doc) { - doc = lead1.advance(next2); - if (next2 != doc) { - continue; - } - } - - // Then find agreement with other iterators - for (DocIdSetIterator other : others) { - // other.docID() may already be equal to doc if we "continued advanceHead" - // on the previous iteration and the advance on the lead exactly matched. - if (other.docID() < doc) { - final int next = other.advance(doc); - - if (next > doc) { - // iterator beyond the current doc - advance lead and continue to the new highest doc. 
- doc = lead1.advance(next); - continue advanceHead; - } - } - } - - // Success - all iterators are on the same doc - return doc; - } - } - - /** - * Try to expand ResumableDISIs when we hit NO_MORE_DOCS - * @return true if any ResumableDISI was expanded - */ - private boolean tryExpandResumableDISIs() throws IOException { - boolean anyExpanded = false; - - // Check all iterators for expansion - for (DocIdSetIterator iterator : allIterators) { - if (iterator instanceof ResumableDISI) { - ResumableDISI resumable = (ResumableDISI) iterator; - if (!resumable.isExhausted()) { - // resumable.resetForNextBatch(); - anyExpanded = true; - } - } - } - - return anyExpanded; - } - - @Override - public long cost() { - long minCost = Long.MAX_VALUE; - for (DocIdSetIterator iterator : allIterators) { - minCost = Math.min(minCost, iterator.cost()); - } - return minCost; - } - - /** - * Reset method for compatibility (no longer needed with new architecture) - */ - public void resetAfterExpansion() throws IOException { - // No-op - expansion is now handled directly in doNext() - } -} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java deleted file mode 100644 index 5e27552f07bec..0000000000000 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateConjunctionScorer.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ - -package org.opensearch.search.approximate; - -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.TwoPhaseIterator; - -import java.io.IOException; -import java.util.List; - -/** - * A custom Scorer that manages an ApproximateConjunctionDISI. - * This class creates and manages an ApproximateConjunctionDISI to score documents - * that match all clauses in a boolean query. - */ -public class ApproximateConjunctionScorer extends Scorer { - private final ApproximateConjunctionDISI approximateConjunctionDISI; - private final float score; - - /** - * Creates a new ApproximateConjunctionScorer. - * - * @param boost The boost factor - * @param scoreMode The score mode - * @param iterators The iterators to coordinate (mix of ResumableDISI and regular DocIdSetIterator) - */ - public ApproximateConjunctionScorer(float boost, ScoreMode scoreMode, List iterators) { - // Scorer doesn't have a constructor that takes arguments - this.approximateConjunctionDISI = new ApproximateConjunctionDISI(iterators); - this.score = boost; - } - - /** - * Creates a new ApproximateConjunctionScorer with an existing conjunction. 
- * - * @param boost The boost factor - * @param scoreMode The score mode - * @param conjunctionDISI The existing conjunction to reuse - */ - public ApproximateConjunctionScorer(float boost, ScoreMode scoreMode, ApproximateConjunctionDISI conjunctionDISI) { - // Reuse the existing conjunction instead of creating a new one - this.approximateConjunctionDISI = conjunctionDISI; - this.score = boost; - } - - @Override - public DocIdSetIterator iterator() { - return approximateConjunctionDISI; - } - - @Override - public float score() throws IOException { - return 0.0f; - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return 0.0f; - } - - @Override - public int docID() { - return approximateConjunctionDISI.docID(); - } - - @Override - public TwoPhaseIterator twoPhaseIterator() { - return null; // No two-phase iteration needed for conjunction - } -} diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 9817a51bec3e0..73662d8bbaa6d 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -38,10 +38,7 @@ import org.opensearch.search.sort.SortOrder; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.Stack; import java.util.function.Function; /** @@ -49,14 +46,6 @@ * after {@code size} is hit */ public class ApproximatePointRangeQuery extends ApproximateQuery { - // Track total documents across all get() calls - private static int totalDocsAdded = 0; - private static int totalGetCalls = 0; - - // Store first 10k docIDs for validation - private static final List firstDocIds = new ArrayList<>(); - private static boolean docIdsCollected = false; - public static final Function LONG_FORMAT = bytes -> 
Long.toString(LongPoint.decodeDimension(bytes, 0)); public static final Function INT_FORMAT = bytes -> Integer.toString(IntPoint.decodeDimension(bytes, 0)); public static final Function HALF_FLOAT_FORMAT = bytes -> Float.toString(HalfFloatPoint.decodeDimension(bytes, 0)); @@ -184,18 +173,8 @@ public void grow(int count) { @Override public void visit(int docID) { - // Log first docID - if (docCount[0] == 0) { -// System.out.println("First docID: " + docID); - } - // firstDocIds.add(docID); - adder.add(docID); docCount[0]++; - // Log when we hit certain milestones - if (docCount[0] >= 10200) { -// System.out.println("Last docID at 10240: " + docID); - } } @Override @@ -205,23 +184,8 @@ public void visit(DocIdSetIterator iterator) throws IOException { @Override public void visit(IntsRef ref) { - // Log first docID from bulk visit - if (docCount[0] == 0) { -// System.out.println("First docID (bulk): " + ref.ints[0]); - } - // - // // Collect first 10240 docIDs for validation - - // for (int i = 0; i < ref.length; i++) { - // firstDocIds.add(ref.ints[i]); - // } - adder.add(ref); docCount[0] += ref.length; - // Log last docID from bulk visit when we hit milestone - if (docCount[0] >= 10240) { - // System.out.println("Last docID (bulk) at " + docCount[0] + ": " + ref.ints[ref.length - 1]); - } } @Override @@ -339,7 +303,7 @@ public void intersectLeftIterative( PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount, - ResumableDISI.BKDState state + BKDState state ) throws IOException { while (true) { @@ -372,81 +336,12 @@ public void intersectLeftIterative( return; } } -// if (docCount[0] >= size) { -// state.setPointTree(pointTree); -// state.setInProgress(true); -// return; -// } - } - } - - public void intersectLeftIterativeNew( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - long[] docCount, - ResumableDISI.BKDState state - ) throws IOException { - - // Stack to track nodes to process - Stack nodeStack = new 
Stack<>(); - nodeStack.push(pointTree.clone()); - - while (!nodeStack.isEmpty() && docCount[0] < size) { - PointValues.PointTree currentTree = nodeStack.pop(); - - if (docCount[0] >= size) { - continue; - } - - PointValues.Relation r = visitor.compare(currentTree.getMinPackedValue(), currentTree.getMaxPackedValue()); - if (r == PointValues.Relation.CELL_OUTSIDE_QUERY) { - continue; - } - - // Handle leaf nodes - if (currentTree.moveToChild() == false) { - if (r == PointValues.Relation.CELL_INSIDE_QUERY) { - currentTree.visitDocIDs(visitor); - } else { - // CELL_CROSSES_QUERY - currentTree.visitDocValues(visitor); - } - continue; - } - - // Internal node processing - PointValues.PointTree leftChild = currentTree.clone(); - PointValues.PointTree rightChild = null; - - // Check if right sibling exists - if (currentTree.moveToSibling()) { - rightChild = currentTree.clone(); - } - - // For CELL_INSIDE_QUERY, check if we can skip right child - if (r == PointValues.Relation.CELL_INSIDE_QUERY && rightChild != null) { - long leftSize = leftChild.size(); - long needed = size - docCount[0]; - - if (leftSize >= needed) { - // Process only left child - nodeStack.push(leftChild); - continue; - } - } - - // Process both children: push right first (so left is processed first due to stack LIFO) - if (rightChild != null) { - nodeStack.push(rightChild); - } - nodeStack.push(leftChild); - if (docCount[0] >= size) { - state.setPointTree(currentTree); - state.setInProgress(true); - return; - } + // if (docCount[0] >= size) { + // state.setPointTree(pointTree); + // state.setInProgress(true); + // return; + // } } - } // custom intersect visitor to walk the right of tree (from rightmost leaf going left) @@ -510,7 +405,6 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { - ResumableDISI.BKDState state = new ResumableDISI.BKDState(); // Keep a visitor for cost 
estimation only DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); @@ -518,9 +412,14 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti @Override public Scorer get(long leadCost) throws IOException { - // Use leadCost as dynamic size if it's reasonable, otherwise use original size - int dynamicSize = (leadCost > 0 && leadCost < Integer.MAX_VALUE) ? (int) leadCost : size; - return getWithSize(dynamicSize); + if (!context.isTopLevel) { + // Use leadCost as dynamic size if it's reasonable, otherwise use original size + int dynamicSize = (leadCost > 0 && leadCost < Integer.MAX_VALUE) ? (int) leadCost : size; + return getWithSize(dynamicSize); + } else { + // For top-level queries, use standard approach + return getWithSize(size); + } } public Scorer getWithSize(int dynamicSize) throws IOException { @@ -529,22 +428,16 @@ public Scorer getWithSize(int dynamicSize) throws IOException { size = dynamicSize; try { - System.out.println("DEBUG: ApproximatePointRangeQuery.get() called with dynamic size: " + dynamicSize); // For windowed approach, create fresh iterator without ResumableDISI state DocIdSetBuilder freshResult = new DocIdSetBuilder(reader.maxDoc(), values); long[] freshDocCount = new long[1]; PointValues.IntersectVisitor freshVisitor = getIntersectVisitor(freshResult, freshDocCount); - System.out.println("DEBUG: Starting fresh BKD traversal from root, target size: " + size); - // Always start fresh traversal from root intersectLeft(values.getPointTree(), freshVisitor, freshDocCount); - System.out.println("DEBUG: Fresh traversal completed, docCount: " + freshDocCount[0]); - DocIdSetIterator iterator = freshResult.build().iterator(); - System.out.println("DEBUG: Built fresh iterator with cost: " + iterator.cost()); return new ConstantScoreScorer(score(), scoreMode, iterator); } finally { // Restore original size @@ -621,8 +514,6 @@ 
private byte[] computeEffectiveBound(SearchContext context, boolean isLowerBound @Override public boolean canApproximate(SearchContext context) { - // System.out.println("canApproximate: false"); - // return false; if (context == null) { return false; } @@ -703,4 +594,38 @@ public final String toString(String field) { return sb.toString(); } + + /** + * Class to track the state of BKD tree traversal. + */ + public static class BKDState { + private PointValues.PointTree currentTree; + private boolean isExhausted = false; + private boolean inProgress = false; + + public PointValues.PointTree getPointTree() { + return currentTree; + } + + public void setPointTree(PointValues.PointTree tree) { + this.currentTree = tree; + } + + public boolean isExhausted() { + return this.isExhausted; + } + + public void setExhausted(boolean exhausted) { + this.isExhausted = exhausted; + } + + public boolean isInProgress() { + return this.inProgress; + } + + public void setInProgress(boolean inProgress) { + this.inProgress = inProgress; + } + + } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index bdb10e8967192..6ac8f1dc6afc0 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -9,7 +9,6 @@ package org.opensearch.search.approximate; import org.apache.lucene.search.BooleanClause; -import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -59,38 +58,11 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { public void setContext(SearchContext context) { resolvedQuery = approximationQuery.canApproximate(context) ? 
approximationQuery : originalQuery; - - boolean needsRewrite = false; - if (resolvedQuery instanceof ApproximateBooleanQuery appxBool) { - if (appxBool.getBooleanQuery().clauses().size() == 1) { - // For single-clause boolean queries, unwrap and process as before - resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); - if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { - appxResolved.setContext(context); + for (BooleanClause boolClause : appxBool.boolQuery.clauses()) { + if (boolClause.query() instanceof ApproximateScoreQuery apprxQuery) { + apprxQuery.setContext(context); } - } else { - for (BooleanClause boolClause : appxBool.boolQuery.clauses()) { - if (boolClause.query() instanceof ApproximateScoreQuery apprxQuery) { - apprxQuery.setContext(context); - } - } - } - needsRewrite = true; - } else if (resolvedQuery instanceof BooleanQuery) { - resolvedQuery = ApproximateBooleanQuery.unwrap(resolvedQuery); - if (resolvedQuery instanceof ApproximateScoreQuery appxResolved) { - appxResolved.setContext(context); - } - needsRewrite = true; - } - - // Only rewrite boolean queries - if (needsRewrite) { - try { - resolvedQuery = resolvedQuery.rewrite(context.searcher()); - } catch (IOException e) { - throw new RuntimeException(e); } } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java b/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java deleted file mode 100644 index 6265178871174..0000000000000 --- a/server/src/main/java/org/opensearch/search/approximate/ResumableDISI.java +++ /dev/null @@ -1,350 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ - -package org.opensearch.search.approximate; - -import org.apache.lucene.index.PointValues; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.Scorer; -import org.apache.lucene.search.ScorerSupplier; - -import java.io.IOException; - -/** - * A resumable DocIdSetIterator that internally expands when it reaches NO_MORE_DOCS. - * On the surface, this behaves identically to a regular DISI, but internally it can - * expand by scoring additional documents when needed. - * - * The expansion is completely internal - external callers see a normal DISI interface - * that continues to return documents even after initially hitting NO_MORE_DOCS. - */ -public class ResumableDISI extends DocIdSetIterator { - private static final int DEFAULT_EXPANSION_SIZE = 10_000; - - private final ScorerSupplier scorerSupplier; - private final int expansionSize; - - // Current state - private DocIdSetIterator currentDisi; - private int currentDocId = -1; - private boolean fullyExhausted = false; - private volatile boolean stopExpansion = false; - - private int documentsScored = 0; // Total documents scored across all expansions - private int documentsReturned = 0; // Count of documents returned by nextDoc() - - // Debug: Add a unique ID to distinguish between ResumableDISI instances - private static int instanceCounter = 0; - private final int instanceId; - - /** - * Creates a new ResumableDISI with the default expansion size of 10,000 documents. - * - * @param scorerSupplier The scorer supplier to get scorers from - */ - public ResumableDISI(ScorerSupplier scorerSupplier) { - this(scorerSupplier, DEFAULT_EXPANSION_SIZE); - } - - /** - * Creates a new ResumableDISI with the specified expansion size. 
- * - * @param scorerSupplier The scorer supplier to get scorers from - * @param expansionSize The number of documents to score in each expansion - */ - public ResumableDISI(ScorerSupplier scorerSupplier, int expansionSize) { - this.scorerSupplier = scorerSupplier; - this.expansionSize = expansionSize; - this.instanceId = ++instanceCounter; - try { - expandInternally(); - } catch (IOException e) { - throw new RuntimeException(e); - } - // System.out.println("DEBUG: Created ResumableDISI instance " + instanceId); - } - - @Override - public int docID() { - return currentDocId; - } - - @Override - public int nextDoc() throws IOException { - if (fullyExhausted) { - return NO_MORE_DOCS; - } - - // If we don't have a current iterator, get one - if (currentDisi == null) { - if (!expandInternally()) { - return NO_MORE_DOCS; - } - // Position the new iterator on its first document - int doc = currentDisi.nextDoc(); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - documentsReturned++; - return doc; - } else { - // Iterator was empty after all - fullyExhausted = true; - currentDocId = NO_MORE_DOCS; - return NO_MORE_DOCS; - } - } - - // Try to get the next document from current iterator - int doc = currentDisi.nextDoc(); - - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - documentsReturned++; - return doc; - } - - // Current iterator exhausted, try to expand internally - if (expandInternally()) { - // Position the new iterator on its first document - doc = currentDisi.nextDoc(); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - documentsReturned++; - return doc; - } - } - - // No more expansion possible - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - EXHAUSTED after returning " + documentsReturned + " total - // documents"); - currentDocId = NO_MORE_DOCS; - fullyExhausted = true; - return NO_MORE_DOCS; - } - - @Override - public int advance(int target) throws IOException { - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance(" 
+ target + ") called, currentDocId: " + currentDocId); - - if (fullyExhausted) { - System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() called but already exhausted"); - return NO_MORE_DOCS; - } - - // If target is NO_MORE_DOCS, no point in expanding - if (target == NO_MORE_DOCS) { - System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() target is NO_MORE_DOCS, marking as exhausted"); - fullyExhausted = true; - currentDocId = NO_MORE_DOCS; - return NO_MORE_DOCS; - } - - // If we don't have a current iterator, get one - if (currentDisi == null) { - if (!expandInternally()) { - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expandInternally failed"); - return NO_MORE_DOCS; - } - // Position the new iterator and check if it meets target - int doc = currentDisi.nextDoc(); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - if (currentDocId >= target) { - return currentDocId; - } - // Otherwise, advance to target - doc = currentDisi.advance(target); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - return doc; - } - } - // Fall through to try expansion - } else { - // Try to advance current iterator - int doc = currentDisi.advance(target); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - // Don't increment documentsReturned - advance() skips documents, doesn't return them one by one - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() existing iterator found doc " + doc + " >= target - // " + target); - return doc; - } - // Current iterator exhausted, try to expand - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() current iterator exhausted, trying to expand"); - } - - // Current iterator exhausted, try to expand internally - while (expandInternally()) { - // Position the new iterator and check if it meets target - // Otherwise, advance to target - int doc = currentDisi.advance(target); - if (doc != NO_MORE_DOCS) { - currentDocId = doc; - return 
doc; - } - // int doc = currentDisi.nextDoc(); - // if (doc != NO_MORE_DOCS) { - // currentDocId = doc; - // if (currentDocId >= target) { - // return currentDocId; - // } - // // Otherwise, advance to target - // doc = currentDisi.advance(target); - // if (doc != NO_MORE_DOCS) { - // currentDocId = doc; - // return doc; - // } - // } - // This expansion didn't have a suitable document, try expanding again - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - advance() expansion didn't find target " + target + ", trying - // next expansion"); - } - - // No more expansion possible - System.out.println( - "DEBUG: ResumableDISI " + instanceId + " - advance() EXHAUSTED after returning " + documentsReturned + " total documents" - ); - currentDocId = NO_MORE_DOCS; - fullyExhausted = true; - return NO_MORE_DOCS; - } - - /** - * Expands the iterator internally by getting a new scorer from the supplier. - * This is called when we hit NO_MORE_DOCS but more documents might be available. - * - * @return true if expansion was successful, false if fully exhausted - * @throws IOException If there's an error getting the scorer - */ - private boolean expandInternally() throws IOException { - if (fullyExhausted || stopExpansion) { - return false; - } - - // // For now, disable expansion after first call to test basic logic - // if (currentDisi != null) { - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - expansion disabled, but NOT marking as exhausted"); - // // Don't set fullyExhausted = true here! 
Let the iterator continue with its current batch - // return false; - // } - - // Get a new scorer from the supplier - this will resume from saved BKD state - System.out.println("DEBUG: ResumableDISI " + instanceId + " - calling expandInternally"); - Scorer scorer = scorerSupplier.get(scorerSupplier.cost()); - if (scorer == null) { - fullyExhausted = true; - return false; - } - - currentDisi = scorer.iterator(); - documentsScored += expansionSize; // Track total documents scored - - // System.out.println("DEBUG: ResumableDISI " + instanceId + " - got iterator with " + currentDisi.cost() + " documents"); - - // Check if the iterator has any documents by looking at cost - if (currentDisi.cost() == 0) { - System.out.println( - "DEBUG: ResumableDISI " + instanceId + " - expandInternally got empty iterator (cost=0), marking as exhausted" - ); - fullyExhausted = true; - currentDocId = NO_MORE_DOCS; - return false; - } - - // Don't position the iterator - let nextDoc() or advance() handle that - return true; - } - - @Override - public long cost() { - return 10_000L; - } - - /** - * Signal this iterator to stop expanding and return NO_MORE_DOCS - */ - public void stopExpansion() { - stopExpansion = true; - } - - /** - * Returns whether this iterator has been fully exhausted. - * - * @return true if there are no more documents to score - */ - public boolean isExhausted() { - return fullyExhausted; - } - - /** - * Returns the total number of documents scored across all expansions. - * - * @return The total number of documents scored - */ - public int getDocumentsScored() { - return documentsScored; - } - - /** - * Class to track the state of BKD tree traversal. 
- */ - public static class BKDState { - private PointValues.PointTree currentTree; - private boolean isExhausted = false; - private long docCount = 0; - private boolean inProgress = false; - private boolean hasSetTree = false; - public boolean needMore = false; - - public PointValues.PointTree getPointTree() { - return currentTree; - } - - public void setPointTree(PointValues.PointTree tree) { - if (tree != null) { - this.currentTree = tree; - } else { - this.currentTree = null; - } - } - - public boolean isExhausted() { - return isExhausted; - } - - public void setExhausted(boolean exhausted) { - this.isExhausted = exhausted; - } - - public long getDocCount() { - return docCount; - } - - public void setDocCount(long count) { - this.docCount = count; - } - - public boolean isInProgress() { - return inProgress; - } - - public void setInProgress(boolean inProgress) { - this.inProgress = inProgress; - } - - public boolean hasSetTree() { - return hasSetTree; - } - - public void setHasSetTree(boolean hasSet) { - this.hasSetTree = hasSet; - } - - } -} From 29f9505fcaa10d5193c06fc9c59254d332d34c99 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Wed, 13 Aug 2025 05:35:04 +0000 Subject: [PATCH 28/38] add proper bounds and license header Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 5753bfda815c3..9bc9784035254 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -1,3 +1,11 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed 
under the Apache-2.0 license or a + * compatible open source license. + */ + package org.opensearch.search.approximate; import org.apache.lucene.index.LeafReaderContext; @@ -25,7 +33,7 @@ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { private final List cachedSuppliers; // Cache suppliers to avoid repeated calls private final ScoreMode scoreMode; private final float boost; - private final int threshold; + private final int size; private long cost = -1; /** @@ -34,7 +42,7 @@ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { * @param clauseWeights The weights for each clause in the boolean query * @param scoreMode The score mode * @param boost The boost factor - * @param threshold The threshold for early termination + * @param size The threshold for early termination * @param context The leaf reader context * @throws IOException If there's an error creating scorer suppliers */ @@ -42,14 +50,14 @@ public ApproximateBooleanScorerSupplier( List clauseWeights, ScoreMode scoreMode, float boost, - int threshold, + int size, LeafReaderContext context ) throws IOException { this.clauseWeights = new ArrayList<>(); this.cachedSuppliers = new ArrayList<>(); this.scoreMode = scoreMode; this.boost = boost; - this.threshold = threshold; + this.size = size; // Store weights and cache their suppliers for (Weight clauseWeight : clauseWeights) { @@ -209,7 +217,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr collector.setScorer(scorer); // Check if we need to expand window - if (totalCollected < 10000 && (globalConjunction == null || globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS)) { + if (totalCollected < size && (globalConjunction == null || globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS)) { currentWindowSize *= 3; // Rebuild iterators with new window size @@ -248,7 +256,7 @@ private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, i // Score documents in 
the range [min, max) for (doc = globalConjunction.docID(); doc < max; doc = globalConjunction.nextDoc()) { - if (totalCollected >= 10000) { + if (totalCollected >= size) { return DocIdSetIterator.NO_MORE_DOCS; // Early termination } @@ -263,7 +271,7 @@ private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, i if (globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS) { // If we need more hits, expand immediately - if (totalCollected < 10000) { + if (totalCollected < size) { currentWindowSize *= 3; try { From bf9f58d9d3cbb417162b5a601be90919a288fbd9 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 14 Aug 2025 03:33:49 +0000 Subject: [PATCH 29/38] fixed failing highlighter test + nested bool check Signed-off-by: Sawan Srivastava --- .../index/query/BoolQueryBuilder.java | 5 ++-- .../approximate/ApproximateBooleanQuery.java | 25 +++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java index 784e23afe55ea..b440cf65762b6 100644 --- a/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java +++ b/server/src/main/java/org/opensearch/index/query/BoolQueryBuilder.java @@ -342,8 +342,9 @@ protected Query doToQuery(QueryShardContext context) throws IOException { query = fixNegativeQueryIfNeeded(query); } - // TODO: Figure out why multi-clause breaks testPhrasePrefix() in HighlighterWithAnalyzersTests.java - if (query instanceof BooleanQuery boolQuery) { + // limit approximate query construction since several mappers (prefixQuery) expect a BooleanQuery not ApproximateBooleanQuery + if (query instanceof BooleanQuery boolQuery + && (boolQuery.getClauses(Occur.FILTER).size() == boolQuery.clauses().size() || boolQuery.clauses().size() == 1)) { return new ApproximateScoreQuery(query, new ApproximateBooleanQuery(boolQuery)); } diff --git 
a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index ed9718e0dfef7..c6a59e1d67508 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -48,6 +48,11 @@ public BooleanQuery getBooleanQuery() { return boolQuery; } + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + return super.rewrite(indexSearcher); + } + @Override protected boolean canApproximate(SearchContext context) { if (context == null) { @@ -64,6 +69,11 @@ protected boolean canApproximate(SearchContext context) { return false; } + // Don't approximate if highlighting is enabled + if (context.highlight() != null) { + return false; + } + // For single clause boolean queries, check if the clause can be approximated if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { // If the clause is already an ApproximateScoreQuery, we can approximate + set context @@ -76,8 +86,19 @@ protected boolean canApproximate(SearchContext context) { return false; } - // return clauses.size() > 1 && clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); - return clauses.stream().allMatch(clause -> clause.occur() == BooleanClause.Occur.FILTER); + // multi clause case - we might want to consider strategies for nested cases, for now limit to just top level + for (BooleanClause clause : clauses) { + if (clause.occur() != BooleanClause.Occur.FILTER) { + return false; + } else { + if (clause.query() instanceof ApproximateScoreQuery appxScore + && appxScore.getApproximationQuery() instanceof ApproximateBooleanQuery) { + return false; + } + } + } + + return true; } @Override From 345a6ad84c7b1cbcd116c29ae51f837fbd4e83ce Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 14 Aug 2025 
19:14:07 +0000 Subject: [PATCH 30/38] add unit tests Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanQueryTests.java | 579 ++++++++++++++++++ 1 file changed, 579 insertions(+) create mode 100644 server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java new file mode 100644 index 0000000000000..6e213c4fe4f8f --- /dev/null +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java @@ -0,0 +1,579 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.approximate; + +import org.apache.lucene.analysis.core.WhitespaceAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.Weight; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.search.aggregations.SearchContextAggregations; +import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.test.OpenSearchTestCase; + +import 
java.io.IOException; +import java.util.Arrays; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ApproximateBooleanQueryTests extends OpenSearchTestCase { + + // Unit Tests for canApproximate method + public void testCanApproximateWithNullContext() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + assertFalse(query.canApproximate(null)); + } + + public void testCanApproximateWithAccurateTotalHits() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(SearchContext.TRACK_TOTAL_HITS_ACCURATE); + + assertFalse(query.canApproximate(mockContext)); + } + + public void testCanApproximateWithAggregations() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(mock(SearchContextAggregations.class)); + + assertFalse(query.canApproximate(mockContext)); + } + + public void testCanApproximateWithHighlighting() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + 
when(mockContext.aggregations()).thenReturn(null); + + SearchHighlightContext mockHighlight = mock(SearchHighlightContext.class); + when(mockHighlight.fields()).thenReturn(Arrays.asList(mock(SearchHighlightContext.Field.class))); + when(mockContext.highlight()).thenReturn(mockHighlight); + + assertFalse(query.canApproximate(mockContext)); + } + + public void testCanApproximateWithValidFilterClauses() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field1", 1, 100), BooleanClause.Occur.FILTER) + .add(IntPoint.newRangeQuery("field2", 200, 300), BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + assertTrue(query.canApproximate(mockContext)); + } + + public void testCanApproximateWithMustNotClause() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field1", 1, 100), BooleanClause.Occur.FILTER) + .add(IntPoint.newRangeQuery("field2", 200, 300), BooleanClause.Occur.MUST_NOT) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + assertFalse(query.canApproximate(mockContext)); + } + + // Unit Tests for ScorerSupplier + public void testScorerSupplierCreation() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + // Add test documents + for (int i = 0; i < 100; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i)); + doc.add(new 
IntPoint("field2", i * 2)); + doc.add(new NumericDocValuesField("field1", i)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + LeafReaderContext leafContext = reader.leaves().get(0); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add( + IntPoint.newRangeQuery("field1", 10, 50), + BooleanClause.Occur.FILTER + ).add(IntPoint.newRangeQuery("field2", 20, 100), BooleanClause.Occur.FILTER).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + ScorerSupplier supplier = weight.scorerSupplier(leafContext); + + assertNotNull(supplier); + assertTrue(supplier instanceof ApproximateBooleanScorerSupplier); + + // Test cost estimation + assertTrue(supplier.cost() > 0); + + // Test scorer creation + Scorer scorer = supplier.get(1000); + assertNotNull(scorer); + } + } + } + } + + // Test with single clause (nested ApproximateScoreQuery case) + public void testSingleClauseApproximation() { + ApproximatePointRangeQuery pointQuery = new ApproximatePointRangeQuery( + "field", + new byte[] { 1 }, + new byte[] { 100 }, + 1, + ApproximatePointRangeQuery.LONG_FORMAT + ); + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("field", 1, 100), pointQuery); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.MUST).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Should delegate to nested query's canApproximate + boolean result = query.canApproximate(mockContext); + assertTrue(result); + } + + public void testSingleClauseMustCanApproximate() { + 
ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field", 1, 100), + new ApproximatePointRangeQuery("field", new byte[] { 1 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.MUST).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Single clause with MUST should be approximatable with ApproximateScoreQuery + assertTrue("Single MUST clause with ApproximateScoreQuery should be approximatable", query.canApproximate(mockContext)); + } + + public void testSingleClauseShouldCanApproximate() { + ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field", 1, 100), + new ApproximatePointRangeQuery("field", new byte[] { 1 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.SHOULD).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Single clause with SHOULD should be approximatable with ApproximateScoreQuery + assertTrue("Single SHOULD clause with ApproximateScoreQuery should be approximatable", query.canApproximate(mockContext)); + } + + public void testSingleClauseFilterCanApproximate() { + ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field", 1, 100), + new ApproximatePointRangeQuery("field", new byte[] { 1 }, new byte[] { 
100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.FILTER).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Single clause with FILTER should be approximatable with ApproximateScoreQuery + assertTrue("Single FILTER clause with ApproximateScoreQuery should be approximatable", query.canApproximate(mockContext)); + } + + // Test BoolQueryBuilder pattern: Single clause WITH ApproximateScoreQuery wrapper + public void testSingleClauseWithApproximateScoreQueryCanApproximate() { + // Create ApproximateScoreQuery wrapper (as BoolQueryBuilder would) + ApproximatePointRangeQuery approxQuery = new ApproximatePointRangeQuery( + "field", + new byte[] { 1 }, + new byte[] { 100 }, + 1, + ApproximatePointRangeQuery.LONG_FORMAT + ); + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("field", 1, 100), approxQuery); + + // Test all single clause types (MUST, SHOULD, FILTER) - all should work + BooleanClause.Occur[] occurs = { BooleanClause.Occur.MUST, BooleanClause.Occur.SHOULD, BooleanClause.Occur.FILTER }; + + for (BooleanClause.Occur occur : occurs) { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(scoreQuery, occur).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Single clause with ApproximateScoreQuery should delegate to nested query + boolean result = query.canApproximate(mockContext); + assertTrue("Single " + occur + " 
clause with ApproximateScoreQuery should be approximatable", result); + } + } + + // Test single MUST_NOT clause should NOT be approximatable + public void testSingleClauseMustNotCannotApproximate() { + ApproximatePointRangeQuery approxQuery = new ApproximatePointRangeQuery( + "field", + new byte[] { 1 }, + new byte[] { 100 }, + 1, + ApproximatePointRangeQuery.LONG_FORMAT + ); + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("field", 1, 100), approxQuery); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.MUST_NOT).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Single MUST_NOT clause should be blocked + assertFalse("Single MUST_NOT clause should not be approximatable", query.canApproximate(mockContext)); + } + + public void testNestedSingleClauseWithApproximateScoreQuery() { + // Create inner ApproximateScoreQuery manually (verbose version) + ApproximatePointRangeQuery innerApproxQuery = new ApproximatePointRangeQuery( + "inner_field", + new byte[] { 50 }, + new byte[] { (byte) 150 }, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + ApproximateScoreQuery innerScoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("inner_field", 50, 150), innerApproxQuery); + + // Inner boolean query (single clause) + BooleanQuery innerBoolQuery = new BooleanQuery.Builder().add(innerScoreQuery, BooleanClause.Occur.FILTER).build(); + + ApproximateBooleanQuery innerApproxBoolQuery = new ApproximateBooleanQuery(innerBoolQuery); + ApproximateScoreQuery outerScoreQuery = new ApproximateScoreQuery(innerBoolQuery, innerApproxBoolQuery); + + // Outer boolean query (single clause containing nested) + BooleanQuery outerBoolQuery = new 
BooleanQuery.Builder().add(outerScoreQuery, BooleanClause.Occur.MUST).build(); + ApproximateBooleanQuery outerQuery = new ApproximateBooleanQuery(outerBoolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Should delegate to nested ApproximateBooleanQuery + boolean result = outerQuery.canApproximate(mockContext); + assertTrue("Nested single clause should follow inner query logic and be approximatable", result); + } + + // Test nested boolean query with ApproximateScoreQuery wrapper (multi-clause pattern) + public void testNestedMultiClauseWithApproximateScoreQuery() { + // Create inner ApproximateScoreQuery instances manually + ApproximateScoreQuery innerQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("inner_field1", 50, 150), + new ApproximatePointRangeQuery( + "inner_field1", + new byte[] { 50 }, + new byte[] { (byte) 150 }, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery innerQuery2 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("inner_field2", 200, 300), + new ApproximatePointRangeQuery( + "inner_field2", + new byte[] { (byte) 200 }, + new byte[] { (byte) 300 }, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + // Inner boolean query (all FILTER clauses) + BooleanQuery innerBoolQuery = new BooleanQuery.Builder().add(innerQuery1, BooleanClause.Occur.FILTER) + .add(innerQuery2, BooleanClause.Occur.FILTER) + .build(); + + ApproximateBooleanQuery innerApproxQuery = new ApproximateBooleanQuery(innerBoolQuery); + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(innerBoolQuery, innerApproxQuery); + + // Create outer ApproximateScoreQuery manually + ApproximateScoreQuery outerFieldQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("outer_field", 1, 100), + new ApproximatePointRangeQuery("outer_field", new 
byte[] { 1 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + ); + + // Outer boolean query (multi-clause with nested) + BooleanQuery outerBoolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.FILTER) + .add(outerFieldQuery, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery outerQuery = new ApproximateBooleanQuery(outerBoolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Should delegate to nested ApproximateBooleanQuery and return true + assertFalse("Nested multi-FILTER clause should not be approximatable", outerQuery.canApproximate(mockContext)); + } + + // Test mixed clause types (should not be approximatable) + public void testMixedClauseTypesCannotApproximate() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field1", 1, 100), BooleanClause.Occur.FILTER) + .add(IntPoint.newRangeQuery("field2", 200, 300), BooleanClause.Occur.MUST) + .add(IntPoint.newRangeQuery("field3", 400, 500), BooleanClause.Occur.SHOULD) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + assertFalse("Mixed clause types should not be approximatable", query.canApproximate(mockContext)); + } + + // Test deeply nested boolean queries + public void testDeeplyNestedBooleanQueries() { + // Level 3 (deepest) - Create ApproximateScoreQuery manually + ApproximateScoreQuery deep1Query = new ApproximateScoreQuery( + IntPoint.newRangeQuery("deep_field1", 1, 50), + new ApproximatePointRangeQuery("deep_field1", new byte[] { 1 }, new byte[] { 50 }, 1, 
ApproximatePointRangeQuery.INT_FORMAT) + ); + ApproximateScoreQuery deep2Query = new ApproximateScoreQuery( + IntPoint.newRangeQuery("deep_field2", 51, 100), + new ApproximatePointRangeQuery("deep_field2", new byte[] { 51 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + ); + + BooleanQuery level3Query = new BooleanQuery.Builder().add(deep1Query, BooleanClause.Occur.FILTER) + .add(deep2Query, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery level3Approx = new ApproximateBooleanQuery(level3Query); + ApproximateScoreQuery level3Score = new ApproximateScoreQuery(level3Query, level3Approx); + + // Level 2 (middle) + ApproximateScoreQuery midQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("mid_field", 200, 300), + new ApproximatePointRangeQuery( + "mid_field", + new byte[] { (byte) 200 }, + new byte[] { (byte) 300 }, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery level2Query = new BooleanQuery.Builder().add(level3Score, BooleanClause.Occur.FILTER) + .add(midQuery, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery level2Approx = new ApproximateBooleanQuery(level2Query); + ApproximateScoreQuery level2Score = new ApproximateScoreQuery(level2Query, level2Approx); + + // Level 1 (top) + ApproximateScoreQuery topFieldQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("top_field", 400, 500), + new ApproximatePointRangeQuery( + "top_field", + new byte[] { (byte) 400 }, + new byte[] { (byte) 500 }, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery level1Query = new BooleanQuery.Builder().add(level2Score, BooleanClause.Occur.FILTER) + .add(topFieldQuery, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery topQuery = new ApproximateBooleanQuery(level1Query); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + 
when(mockContext.highlight()).thenReturn(null); + + assertFalse("Deeply nested all-FILTER queries should not be approximatable", topQuery.canApproximate(mockContext)); + } + + // Test edge case: nested query with highlighting should be blocked + public void testNestedQueryWithHighlightingBlocked() { + // Inner boolean query (all FILTER clauses) + BooleanQuery innerBoolQuery = new BooleanQuery.Builder().add( + IntPoint.newRangeQuery("inner_field1", 50, 150), + BooleanClause.Occur.FILTER + ).add(IntPoint.newRangeQuery("inner_field2", 200, 300), BooleanClause.Occur.FILTER).build(); + + ApproximateBooleanQuery innerApproxQuery = new ApproximateBooleanQuery(innerBoolQuery); + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(innerBoolQuery, innerApproxQuery); + + // Outer boolean query + BooleanQuery outerBoolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.FILTER).build(); + ApproximateBooleanQuery outerQuery = new ApproximateBooleanQuery(outerBoolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + + // Add highlighting + SearchHighlightContext mockHighlight = mock(SearchHighlightContext.class); + when(mockHighlight.fields()).thenReturn(Arrays.asList(mock(SearchHighlightContext.Field.class))); + when(mockContext.highlight()).thenReturn(mockHighlight); + + assertFalse("Nested queries with highlighting should be blocked", outerQuery.canApproximate(mockContext)); + } + + // Test edge case: nested query with one level having MUST_NOT + public void testNestedQueryWithMustNotClause() { + // Inner boolean query (contains MUST_NOT) + BooleanQuery innerBoolQuery = new BooleanQuery.Builder().add( + IntPoint.newRangeQuery("inner_field1", 50, 150), + BooleanClause.Occur.FILTER + ).add(IntPoint.newRangeQuery("inner_field2", 200, 300), BooleanClause.Occur.MUST_NOT).build(); + + ApproximateBooleanQuery innerApproxQuery = 
new ApproximateBooleanQuery(innerBoolQuery); + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(innerBoolQuery, innerApproxQuery); + + // Outer boolean query (all FILTER) + BooleanQuery outerBoolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.FILTER) + .add(IntPoint.newRangeQuery("outer_field", 1, 100), BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery outerQuery = new ApproximateBooleanQuery(outerBoolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + // Should be blocked by inner MUST_NOT clause + assertFalse("Nested query with MUST_NOT should not be approximatable", outerQuery.canApproximate(mockContext)); + } + + // Test window size heuristic with different cost scenarios + public void testWindowSizeHeuristic() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + for (int i = 0; i < 1000; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i)); + doc.add(new IntPoint("field2", i * 2)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + LeafReaderContext leafContext = reader.leaves().get(0); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add( + IntPoint.newRangeQuery("field1", 100, 900), + BooleanClause.Occur.FILTER + ).add(IntPoint.newRangeQuery("field2", 200, 1800), BooleanClause.Occur.FILTER).build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + ApproximateBooleanScorerSupplier supplier = (ApproximateBooleanScorerSupplier) weight.scorerSupplier(leafContext); + + assertNotNull(supplier); 
+ + // Test that cost calculation works + long cost = supplier.cost(); + assertTrue("Cost should be positive", cost > 0); + } + } + } + } + + public void testApproximateQueryValidation( + IndexSearcher searcher, + String field1, + String field2, + int lower1, + int upper1, + int lower2, + int upper2, + int size + ) throws IOException { + // Test with approximate query + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery(field1, lower1, upper1), BooleanClause.Occur.FILTER) + .add(IntPoint.newRangeQuery(field2, lower2, upper2), BooleanClause.Occur.FILTER) + .build(); + ApproximateScoreQuery approxQuery = new ApproximateScoreQuery(boolQuery, new ApproximateBooleanQuery(boolQuery)); + + TopDocs approxDocs = searcher.search(approxQuery, size); + + // Validate hit count + assertTrue("Approximate query should return at most " + size + " docs", approxDocs.scoreDocs.length <= size); + assertTrue("Should not exceed 10k hits", approxDocs.totalHits.value() <= 10000); + + // Validate hit accuracy - each returned doc should match the query criteria + StoredFields storedFields = searcher.getIndexReader().storedFields(); + for (int i = 0; i < approxDocs.scoreDocs.length; i++) { + int docId = approxDocs.scoreDocs[i].doc; + Document doc = storedFields.document(docId); + + int field1Value = doc.getField(field1).numericValue().intValue(); + int field2Value = doc.getField(field2).numericValue().intValue(); + + assertTrue( + field1 + " should be in range [" + lower1 + ", " + upper1 + "], got: " + field1Value, + field1Value >= lower1 && field1Value <= upper1 + ); + assertTrue( + field2 + " should be in range [" + lower2 + ", " + upper2 + "], got: " + field2Value, + field2Value >= lower2 && field2Value <= upper2 + ); + } + } +} From c162ddf30d4602ce720e414be7a72961c739b2a2 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 14 Aug 2025 23:31:10 +0000 Subject: [PATCH 31/38] fix more tests Signed-off-by: Sawan Srivastava --- 
.../approximate/ApproximateBooleanQuery.java | 29 ++++++++++++++++--- .../approximate/ApproximateScoreQuery.java | 4 +++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index c6a59e1d67508..4b25e23c13778 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -53,6 +53,21 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); } + public static Query boolRewrite(Query query, IndexSearcher indexSearcher) { + if (query instanceof BooleanQuery boolQuery) { + return (boolQuery.clauses().size() == 1) ? boolRewrite(boolQuery.clauses().get(0).query(), indexSearcher) : query; + } else if (query instanceof ApproximateBooleanQuery appxBool) { + return (appxBool.getBooleanQuery().clauses().size() == 1) + ? 
boolRewrite(appxBool.boolQuery.clauses().get(0).query(), indexSearcher) + : query; + } + try { + return query.rewrite(indexSearcher); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + @Override protected boolean canApproximate(SearchContext context) { if (context == null) { @@ -86,19 +101,25 @@ protected boolean canApproximate(SearchContext context) { return false; } + boolean hasApproximate = false; + // multi clause case - we might want to consider strategies for nested cases, for now limit to just top level for (BooleanClause clause : clauses) { if (clause.occur() != BooleanClause.Occur.FILTER) { return false; } else { - if (clause.query() instanceof ApproximateScoreQuery appxScore - && appxScore.getApproximationQuery() instanceof ApproximateBooleanQuery) { - return false; + if (clause.query() instanceof ApproximateScoreQuery appxScore) { + if (appxScore.getApproximationQuery() instanceof ApproximatePointRangeQuery) { + hasApproximate = true; + } + if (appxScore.getApproximationQuery() instanceof ApproximateBooleanQuery || clause.query() instanceof BooleanQuery) { + return false; + } } } } - return true; + return hasApproximate; } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index 6ac8f1dc6afc0..1a8e23a035f27 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -9,6 +9,7 @@ package org.opensearch.search.approximate; import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; @@ -65,6 +66,9 @@ public void setContext(SearchContext context) { } } } + if ((resolvedQuery instanceof BooleanQuery) || 
(resolvedQuery instanceof ApproximateBooleanQuery)) { + resolvedQuery = ApproximateBooleanQuery.boolRewrite(resolvedQuery, context.searcher()); + } } @Override From b7b924d25008c6b210a4d226edabc4c3614fc439 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Mon, 18 Aug 2025 21:04:49 +0000 Subject: [PATCH 32/38] partially working integ tests Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 12 + .../ApproximatePointRangeQuery.java | 24 +- .../ApproximateBooleanQueryTests.java | 669 +++++++++++++++++- 3 files changed, 665 insertions(+), 40 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 9bc9784035254..c692b33c9d681 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -198,6 +198,10 @@ private List rebuildIteratorsWithWindowSize(int windowSize) th approxQuery.setSize(windowSize); try { Scorer scorer = supplier.get(windowSize); + if (scorer == null) { + // Clause is fully traversed, end entire conjunction + return null; + } newIterators.add(scorer.iterator()); } finally { // Restore original size @@ -222,6 +226,10 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // Rebuild iterators with new window size List newIterators = rebuildIteratorsWithWindowSize(currentWindowSize); + if (newIterators == null) { + // A clause is fully traversed, end conjunction + return DocIdSetIterator.NO_MORE_DOCS; + } globalConjunction = ConjunctionUtils.intersectIterators(newIterators); // Return first docID from new conjunction (could be < min) @@ -276,6 +284,10 @@ private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, i try { List newIterators = 
rebuildIteratorsWithWindowSize(currentWindowSize); + if (newIterators == null) { + // A clause is fully traversed, end conjunction + return DocIdSetIterator.NO_MORE_DOCS; + } globalConjunction = ConjunctionUtils.intersectIterators(newIterators); int firstDoc = globalConjunction.nextDoc(); diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 73662d8bbaa6d..4aef6c0336012 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -400,7 +400,29 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } // values.size(): total points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) if (size > values.size()) { - return pointRangeQueryWeight.scorerSupplier(context); + ScorerSupplier luceneSupplier = pointRangeQueryWeight.scorerSupplier(context); + return new ScorerSupplier() { + boolean alreadyFullyTraversed = false; + + @Override + public Scorer get(long leadCost) throws IOException { + return getWithSize(size); + } + + public Scorer getWithSize(int dynamicSize) throws IOException { + if (alreadyFullyTraversed) { + return null; // Signal end of conjunction + } + + alreadyFullyTraversed = true; + return luceneSupplier.get(Long.MAX_VALUE); + } + + @Override + public long cost() { + return luceneSupplier.cost(); + } + }; } else { if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java index 6e213c4fe4f8f..07a1b7929896a 100644 --- 
a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java @@ -12,27 +12,42 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.StoredField; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.StoredFields; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TopScoreDocCollectorManager; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.SearchOperationListener; +import org.opensearch.search.aggregations.BucketCollectorProcessor; import org.opensearch.search.aggregations.SearchContextAggregations; import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext; +import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import static org.mockito.ArgumentMatchers.any; import static 
org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -87,8 +102,29 @@ public void testCanApproximateWithHighlighting() { } public void testCanApproximateWithValidFilterClauses() { - BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field1", 1, 100), BooleanClause.Occur.FILTER) - .add(IntPoint.newRangeQuery("field2", 200, 300), BooleanClause.Occur.FILTER) + ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field1", 1, 100), + new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field2", 200, 300), + new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { 200 }).bytes, + IntPoint.pack(new int[] { 300 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); @@ -97,6 +133,9 @@ public void testCanApproximateWithValidFilterClauses() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); + approxQuery1.setContext(mockContext); + approxQuery2.setContext(mockContext); + assertTrue(query.canApproximate(mockContext)); } @@ -119,7 +158,7 @@ public void testScorerSupplierCreation() throws IOException { try (Directory directory = newDirectory()) { try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { // Add test documents - for (int i = 0; i < 100; i++) { + for (int i = 0; i < 20000; i++) { Document doc = new Document(); doc.add(new IntPoint("field1", i)); doc.add(new IntPoint("field2", i * 2)); @@ -155,12 +194,137 @@ public void 
testScorerSupplierCreation() throws IOException { } } + // Integration test comparing approximate vs exact results + public void testApproximateVsExactResults() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int numDocs = 12000; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i % 1000)); + doc.add(new IntPoint("field2", (i * 3) % 1000)); + doc.add(new NumericDocValuesField("field1", i)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + + int lower1 = 200; + int upper1 = 400; + int lower2 = 300; + int upper2 = 500; + + // Create approximate query + ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { lower1 }).bytes, + IntPoint.pack(new int[] { upper1 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { lower2 }).bytes, + IntPoint.pack(new int[] { upper2 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + + // ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery(IntPoint.newRangeQuery("field1", lower1, upper1), new + // ApproximatePointRangeQuery("field1", IntPoint.pack(new int[]{lower1}).bytes, IntPoint.pack(new int[]{upper1}).bytes, + // 1, ApproximatePointRangeQuery.INT_FORMAT)); + // ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery(IntPoint.newRangeQuery("field2", lower2, upper2), new + // ApproximatePointRangeQuery("field2", IntPoint.pack(new int[]{lower2}).bytes, IntPoint.pack(new int[]{upper2}).bytes, + // 1, ApproximatePointRangeQuery.INT_FORMAT)); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, 
BooleanClause.Occur.FILTER) + .build(); + + ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(boolQuery); + + // Create exact query (same boolean structure) + Query exactQuery = boolQuery; + + // Search with both queries + TopDocs approximateDocs = searcher.search(approximateQuery, 1000); + TopDocs exactDocs = searcher.search(exactQuery, 1000); + + // Results should be identical when approximation is not triggered + // or when we collect all available documents + if (exactDocs.totalHits.value() <= 1000) { + assertEquals( + "Approximate and exact should return same number of docs when under limit", + exactDocs.totalHits.value(), + approximateDocs.totalHits.value() + ); + } + } + } + } + } + + // Test early termination at 10k hits + public void testEarlyTerminationAt10k() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + // Create enough documents to exceed 10k hits + for (int i = 0; i < 20000; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i % 100)); // High overlap + doc.add(new IntPoint("field2", i % 50)); // High overlap + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Create query that should match many documents + + ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field1", 0, 99), + new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { 0 }).bytes, + IntPoint.pack(new int[] { 99 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field2", 0, 49), + new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { 0 }).bytes, + IntPoint.pack(new int[] { 49 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + 
BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + TopDocs docs = searcher.search(query, 15000); + + // Should terminate early at exactly 10k hits + assertEquals("Should collect exactly 10k documents", 10000, docs.totalHits.value()); + } + } + } + } + // Test with single clause (nested ApproximateScoreQuery case) public void testSingleClauseApproximation() { ApproximatePointRangeQuery pointQuery = new ApproximatePointRangeQuery( "field", - new byte[] { 1 }, - new byte[] { 100 }, + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, 1, ApproximatePointRangeQuery.LONG_FORMAT ); @@ -174,18 +338,67 @@ public void testSingleClauseApproximation() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); + scoreQuery.setContext(mockContext); + // Should delegate to nested query's canApproximate boolean result = query.canApproximate(mockContext); assertTrue(result); } - public void testSingleClauseMustCanApproximate() { - ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field", 1, 100), - new ApproximatePointRangeQuery("field", new byte[] { 1 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + // Test BoolQueryBuilder pattern: All FILTER clauses (multi-clause) + public void testAllFilterClausesCanApproximate() { + // Create approximatable range queries manually + ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field1", 1, 100), + new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field2", 200, 300), + new 
ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { 200 }).bytes, + IntPoint.pack(new int[] { 300 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery3 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field3", 400, 500), + new ApproximatePointRangeQuery( + "field3", + IntPoint.pack(new int[] { 400 }).bytes, + IntPoint.pack(new int[] { 500 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) ); - BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.MUST).build(); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .add(approxQuery3, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + SearchContext mockContext = mock(SearchContext.class); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.highlight()).thenReturn(null); + + approxQuery1.setContext(mockContext); + approxQuery2.setContext(mockContext); + approxQuery3.setContext(mockContext); + + assertTrue("All FILTER clauses should be approximatable", query.canApproximate(mockContext)); + } + + public void testSingleClauseMustCanApproximate() { + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.MUST).build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); SearchContext mockContext = mock(SearchContext.class); @@ -193,8 +406,8 @@ public void testSingleClauseMustCanApproximate() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); - // Single clause with MUST should be approximatable with ApproximateScoreQuery - assertTrue("Single MUST clause with ApproximateScoreQuery should be approximatable", query.canApproximate(mockContext)); + 
// Single clause with MUST should return false (not handled by current logic) + assertFalse("Single MUST clause should not be approximatable", query.canApproximate(mockContext)); } public void testSingleClauseShouldCanApproximate() { @@ -211,17 +424,15 @@ public void testSingleClauseShouldCanApproximate() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); + approxQuery.setContext(mockContext); + // Single clause with SHOULD should be approximatable with ApproximateScoreQuery assertTrue("Single SHOULD clause with ApproximateScoreQuery should be approximatable", query.canApproximate(mockContext)); } public void testSingleClauseFilterCanApproximate() { - ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field", 1, 100), - new ApproximatePointRangeQuery("field", new byte[] { 1 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) - ); - - BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.FILTER).build(); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.FILTER) + .build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); SearchContext mockContext = mock(SearchContext.class); @@ -229,8 +440,8 @@ public void testSingleClauseFilterCanApproximate() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); - // Single clause with FILTER should be approximatable with ApproximateScoreQuery - assertTrue("Single FILTER clause with ApproximateScoreQuery should be approximatable", query.canApproximate(mockContext)); + // Single clause with FILTER should return false (not MUST_NOT, but not handled) + assertFalse("Single FILTER clause should not be approximatable", query.canApproximate(mockContext)); } // Test BoolQueryBuilder pattern: Single clause WITH ApproximateScoreQuery wrapper @@ -238,8 +449,8 @@ public 
void testSingleClauseWithApproximateScoreQueryCanApproximate() { // Create ApproximateScoreQuery wrapper (as BoolQueryBuilder would) ApproximatePointRangeQuery approxQuery = new ApproximatePointRangeQuery( "field", - new byte[] { 1 }, - new byte[] { 100 }, + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, 1, ApproximatePointRangeQuery.LONG_FORMAT ); @@ -267,8 +478,8 @@ public void testSingleClauseWithApproximateScoreQueryCanApproximate() { public void testSingleClauseMustNotCannotApproximate() { ApproximatePointRangeQuery approxQuery = new ApproximatePointRangeQuery( "field", - new byte[] { 1 }, - new byte[] { 100 }, + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, 1, ApproximatePointRangeQuery.LONG_FORMAT ); @@ -290,8 +501,8 @@ public void testNestedSingleClauseWithApproximateScoreQuery() { // Create inner ApproximateScoreQuery manually (verbose version) ApproximatePointRangeQuery innerApproxQuery = new ApproximatePointRangeQuery( "inner_field", - new byte[] { 50 }, - new byte[] { (byte) 150 }, + IntPoint.pack(new int[] { 50 }).bytes, + IntPoint.pack(new int[] { 150 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ); @@ -324,8 +535,8 @@ public void testNestedMultiClauseWithApproximateScoreQuery() { IntPoint.newRangeQuery("inner_field1", 50, 150), new ApproximatePointRangeQuery( "inner_field1", - new byte[] { 50 }, - new byte[] { (byte) 150 }, + IntPoint.pack(new int[] { 50 }).bytes, + IntPoint.pack(new int[] { 150 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ) @@ -334,8 +545,8 @@ public void testNestedMultiClauseWithApproximateScoreQuery() { IntPoint.newRangeQuery("inner_field2", 200, 300), new ApproximatePointRangeQuery( "inner_field2", - new byte[] { (byte) 200 }, - new byte[] { (byte) 300 }, + IntPoint.pack(new int[] { 200 }).bytes, + IntPoint.pack(new int[] { 300 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ) @@ -352,7 +563,13 @@ public void 
testNestedMultiClauseWithApproximateScoreQuery() { // Create outer ApproximateScoreQuery manually ApproximateScoreQuery outerFieldQuery = new ApproximateScoreQuery( IntPoint.newRangeQuery("outer_field", 1, 100), - new ApproximatePointRangeQuery("outer_field", new byte[] { 1 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + new ApproximatePointRangeQuery( + "outer_field", + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) ); // Outer boolean query (multi-clause with nested) @@ -391,11 +608,23 @@ public void testDeeplyNestedBooleanQueries() { // Level 3 (deepest) - Create ApproximateScoreQuery manually ApproximateScoreQuery deep1Query = new ApproximateScoreQuery( IntPoint.newRangeQuery("deep_field1", 1, 50), - new ApproximatePointRangeQuery("deep_field1", new byte[] { 1 }, new byte[] { 50 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + new ApproximatePointRangeQuery( + "deep_field1", + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 50 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) ); ApproximateScoreQuery deep2Query = new ApproximateScoreQuery( IntPoint.newRangeQuery("deep_field2", 51, 100), - new ApproximatePointRangeQuery("deep_field2", new byte[] { 51 }, new byte[] { 100 }, 1, ApproximatePointRangeQuery.INT_FORMAT) + new ApproximatePointRangeQuery( + "deep_field2", + IntPoint.pack(new int[] { 51 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) ); BooleanQuery level3Query = new BooleanQuery.Builder().add(deep1Query, BooleanClause.Occur.FILTER) @@ -409,8 +638,8 @@ public void testDeeplyNestedBooleanQueries() { IntPoint.newRangeQuery("mid_field", 200, 300), new ApproximatePointRangeQuery( "mid_field", - new byte[] { (byte) 200 }, - new byte[] { (byte) 300 }, + IntPoint.pack(new int[] { 200 }).bytes, + IntPoint.pack(new int[] { 300 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ) @@ 
-427,8 +656,8 @@ public void testDeeplyNestedBooleanQueries() { IntPoint.newRangeQuery("top_field", 400, 500), new ApproximatePointRangeQuery( "top_field", - new byte[] { (byte) 400 }, - new byte[] { (byte) 500 }, + IntPoint.pack(new int[] { 400 }).bytes, + IntPoint.pack(new int[] { 500 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ) @@ -500,6 +729,278 @@ public void testNestedQueryWithMustNotClause() { assertFalse("Nested query with MUST_NOT should not be approximatable", outerQuery.canApproximate(mockContext)); } + // Test BulkScorer windowed approach with small dataset + public void testBulkScorerWindowedExpansionSmall() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + // Add documents with overlapping ranges + for (int i = 0; i < 1000; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i)); + doc.add(new IntPoint("field2", i % 100)); // Create overlapping ranges + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + LeafReaderContext leafContext = reader.leaves().get(0); + + ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field1", 100, 900), + new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { 100 }).bytes, + IntPoint.pack(new int[] { 900 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field2", 10, 90), + new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { 10 }).bytes, + IntPoint.pack(new int[] { 90 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .build(); + 
ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + ScorerSupplier supplier = weight.scorerSupplier(leafContext); + BulkScorer bulkScorer = supplier.bulkScorer(); + + assertNotNull(bulkScorer); + + // Test bulk scoring with collection + List collectedDocs = new ArrayList<>(); + LeafCollector collector = new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + collectedDocs.add(doc); + } + }; + + int result = bulkScorer.score(collector, null, 0, Integer.MAX_VALUE); + + // Should collect documents + assertTrue("Should collect some documents", collectedDocs.size() > 0); + assertTrue("Should collect reasonable number of documents", collectedDocs.size() <= 1000); + } + } + } + } + + // Test BulkScorer with large dataset to trigger windowed expansion + public void testBulkScorerWindowedExpansionLarge() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int numDocs = 20000; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i)); + doc.add(new IntPoint("field2", i % 1000)); // Create dense overlapping ranges + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + LeafReaderContext leafContext = reader.leaves().get(0); + + ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field1", 1000, 20000), + new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { 1000 }).bytes, + IntPoint.pack(new int[] { 20000 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( + 
IntPoint.newRangeQuery("field2", 100, 900), + new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { 100 }).bytes, + IntPoint.pack(new int[] { 900 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + ScorerSupplier supplier = weight.scorerSupplier(leafContext); + BulkScorer bulkScorer = supplier.bulkScorer(); + + assertNotNull(bulkScorer); + + // Test bulk scoring with collection + List collectedDocs = new ArrayList<>(); + LeafCollector collector = new LeafCollector() { + @Override + public void setScorer(Scorable scorer) throws IOException {} + + @Override + public void collect(int doc) throws IOException { + collectedDocs.add(doc); + } + }; + + int result = bulkScorer.score(collector, null, 0, Integer.MAX_VALUE); + + // Should collect documents and potentially expand windows + assertTrue("Should collect some documents", collectedDocs.size() > 0); + assertTrue("Should collect up to 10k documents or exhaust", collectedDocs.size() <= 10000); + } + } + } + } + + // Integration test validating hit count and accuracy + public void testApproximateResultsValidation() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int numDocs = 20000; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + int field1Value = i % 10; + int field2Value = (i * 2) % 10; + doc.add(new IntPoint("field1", field1Value)); + doc.add(new IntPoint("field2", field2Value)); + doc.add(new NumericDocValuesField("field1", field1Value)); + doc.add(new NumericDocValuesField("field2", field2Value)); + doc.add(new 
StoredField("field1", field1Value)); + doc.add(new StoredField("field2", field2Value)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + + SearchContext searchContext = mock(SearchContext.class); + IndexShard indexShard = mock(IndexShard.class); + when(searchContext.indexShard()).thenReturn(indexShard); + SearchOperationListener searchOperationListener = new SearchOperationListener() { + }; + when(indexShard.getSearchOperationListener()).thenReturn(searchOperationListener); + when(searchContext.bucketCollectorProcessor()).thenReturn(new BucketCollectorProcessor()); + when(searchContext.asLocalBucketCountThresholds(any())).thenCallRealMethod(); + + // ContextIndexSearcher searcher = mock(ContextIndexSearcher.class); + ContextIndexSearcher searcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + mock(ExecutorService.class), + searchContext + ); + + searcher.addQueryCancellation(() -> {}); + + int lower1 = 2; + int upper1 = 5; + int lower2 = 4; + int upper2 = 5; + + // Create approximate query + ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { lower1 }).bytes, + IntPoint.pack(new int[] { upper1 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { lower2 }).bytes, + IntPoint.pack(new int[] { upper2 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + + BooleanQuery approximateBoolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(approximateBoolQuery); + + // Create exact query (regular Lucene BooleanQuery) + BooleanQuery exactBoolQuery = new 
BooleanQuery.Builder().add( + IntPoint.newRangeQuery("field1", lower1, upper1), + BooleanClause.Occur.FILTER + ).add(IntPoint.newRangeQuery("field2", lower2, upper2), BooleanClause.Occur.FILTER).build(); + + TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + + searcher.search(approximateQuery, collector); + + // Search with both queries + TopDocs approximateDocs = collector.topDocs(); + + TopScoreDocCollector collectorExact = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + + searcher.search(exactBoolQuery, collectorExact); + + // Search with both queries + TopDocs exactDocs = collectorExact.topDocs(); + + System.out.println("Exact hits: " + exactDocs.totalHits.value()); + System.out.println("Approximate hits: " + approximateDocs.totalHits.value()); + System.out.println("approximate score docs length: " + approximateDocs.scoreDocs.length); + // Validate hit count logic + if (exactDocs.totalHits.value() <= 10000) { + assertEquals( + "When exact results ≤ 10k, approximate should match exactly", + exactDocs.totalHits.value(), + approximateDocs.totalHits.value() + ); + } else { + assertEquals( + "Approximate should return exactly 10k hits when exact > 10k", + 10000, + approximateDocs.totalHits.value() + ); + } + + // Validate hit accuracy - each returned doc should match the query criteria + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < approximateDocs.scoreDocs.length; i++) { + int docId = approximateDocs.scoreDocs[i].doc; + Document doc = storedFields.document(docId); + + int field1Value = doc.getField("field1").numericValue().intValue(); + int field2Value = doc.getField("field2").numericValue().intValue(); + + assertTrue( + "field1 should be in range [" + lower1 + ", " + upper1 + "], got: " + field1Value, + field1Value >= lower1 && field1Value <= upper1 + ); + assertTrue( + "field2 should be in range [" + lower2 + ", " + upper2 + "], got: " + field2Value, + field2Value >= lower2 
&& field2Value <= upper2 + ); + } + } + } + } + } + // Test window size heuristic with different cost scenarios public void testWindowSizeHeuristic() throws IOException { try (Directory directory = newDirectory()) { @@ -535,6 +1036,75 @@ public void testWindowSizeHeuristic() throws IOException { } } + // Test sparse data distribution (simulating http_logs dataset) + public void testSparseDataDistribution() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + String fieldName1 = "timestamp"; + String fieldName2 = "status_code"; + + // Create sparse timestamp distribution with dense status code clusters + for (int i = 0; i < 10000; i++) { + Document doc = new Document(); + // Sparse timestamps (gaps in time) + int timestamp = i * 10 + (i % 6); + // Dense status code clusters (200s, 400s, 500s) + int statusCode = (i % 100) < 70 ? 200 + (i % 11) : ((i % 100) < 80 ? 400 + (i % 11) : 500 + (i % 11)); + + doc.add(new IntPoint(fieldName1, timestamp)); + doc.add(new IntPoint(fieldName2, statusCode)); + doc.add(new NumericDocValuesField(fieldName1, timestamp)); + doc.add(new NumericDocValuesField(fieldName2, statusCode)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Test query for specific time range and status codes + testApproximateQueryValidation(searcher, fieldName1, fieldName2, 10000, 50000, 200, 500, 100); + testApproximateQueryValidation(searcher, fieldName1, fieldName2, 0, 20000, 404, 404, 50); + } + } + } + } + + // Test dense data distribution (simulating nyc_taxis dataset) + public void testDenseDataDistribution() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + String fieldName1 = "fare_amount"; + String fieldName2 = 
"trip_distance"; + + // Create dense overlapping distributions + for (int fare = 500; fare <= 5000; fare += 50) { // Dense fare distribution + for (int distance = 1; distance <= 50; distance += 2) { // Dense distance distribution + // Add multiple documents per combination to create density + int numDocs = 3; + for (int d = 0; d < numDocs; d++) { + Document doc = new Document(); + doc.add(new IntPoint(fieldName1, fare)); + doc.add(new IntPoint(fieldName2, distance)); + doc.add(new NumericDocValuesField(fieldName1, fare)); + doc.add(new NumericDocValuesField(fieldName2, distance)); + iw.addDocument(doc); + } + } + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + + // Test queries for different fare and distance ranges + testApproximateQueryValidation(searcher, fieldName1, fieldName2, 1000, 3000, 5, 25, 200); + testApproximateQueryValidation(searcher, fieldName1, fieldName2, 2000, 4000, 10, 40, 500); + } + } + } + } + public void testApproximateQueryValidation( IndexSearcher searcher, String field1, @@ -546,8 +1116,29 @@ public void testApproximateQueryValidation( int size ) throws IOException { // Test with approximate query - BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery(field1, lower1, upper1), BooleanClause.Occur.FILTER) - .add(IntPoint.newRangeQuery(field2, lower2, upper2), BooleanClause.Occur.FILTER) + ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( + IntPoint.newRangeQuery(field1, lower1, upper1), + new ApproximatePointRangeQuery( + field1, + IntPoint.pack(new int[] { lower1 }).bytes, + IntPoint.pack(new int[] { upper1 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( + IntPoint.newRangeQuery(field2, lower2, upper2), + new ApproximatePointRangeQuery( + field2, + IntPoint.pack(new int[] { lower2 }).bytes, + IntPoint.pack(new int[] { upper2 }).bytes, + 1, + 
ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); ApproximateScoreQuery approxQuery = new ApproximateScoreQuery(boolQuery, new ApproximateBooleanQuery(boolQuery)); From 9f2a326fac549fc1cf9f5dd5882a9a7c4d750c44 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 22 Aug 2025 04:35:47 +0000 Subject: [PATCH 33/38] fixed duplicate doc collecting + more integ tests Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanScorerSupplier.java | 50 +- .../ApproximatePointRangeQuery.java | 56 +- .../ApproximateBooleanQueryTests.java | 584 +++++++----------- 3 files changed, 292 insertions(+), 398 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index c692b33c9d681..501f530eca456 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.BitSet; import java.util.List; /** @@ -35,6 +36,7 @@ public class ApproximateBooleanScorerSupplier extends ScorerSupplier { private final float boost; private final int size; private long cost = -1; + private int scalingFactor = 3; /** * Creates a new ApproximateBooleanScorerSupplier. @@ -69,6 +71,10 @@ public ApproximateBooleanScorerSupplier( } } + public void setScalingWindowFactor(int factor) { + scalingFactor = factor; + } + /** * Get the {@link Scorer}. This may not return {@code null} and must be called at most once. 
* @@ -80,7 +86,7 @@ public Scorer get(long leadCost) throws IOException { return null; } - // Create appropriate iterators for each clause - ResumableDISI only for approximatable queries + // Create appropriate iterators for each clause List clauseIterators = new ArrayList<>(clauseWeights.size()); for (int i = 0; i < clauseWeights.size(); i++) { // Use regular DocIdSetIterator for non-approximatable queries @@ -179,6 +185,7 @@ public int docID() { // Create a simple bulk scorer that wraps the conjunction return new BulkScorer() { private int totalCollected = 0; + private BitSet collectedDocs = new BitSet(); // Track collected documents // Windowed approach state private int currentWindowSize = initialWindowSize; @@ -186,6 +193,8 @@ public int docID() { private List rebuildIteratorsWithWindowSize(int windowSize) throws IOException { List newIterators = new ArrayList<>(); + boolean allClausesFullyTraversed = true; + for (int i = 0; i < clauseWeights.size(); i++) { Weight weight = clauseWeights.get(i); ScorerSupplier supplier = cachedSuppliers.get(i); // Use cached supplier @@ -196,6 +205,7 @@ private List rebuildIteratorsWithWindowSize(int windowSize) th // Temporarily set the size int originalSize = approxQuery.getSize(); approxQuery.setSize(windowSize); + try { Scorer scorer = supplier.get(windowSize); if (scorer == null) { @@ -203,16 +213,25 @@ private List rebuildIteratorsWithWindowSize(int windowSize) th return null; } newIterators.add(scorer.iterator()); + + // Check if this clause has been fully traversed + if (!approxQuery.getFullyTraversed()) { + allClausesFullyTraversed = false; + } } finally { // Restore original size approxQuery.setSize(originalSize); } } else { - // Regular queries use full cost + // Regular queries use full cost - always fully traversed Scorer scorer = supplier.get(supplier.cost()); newIterators.add(scorer.iterator()); } } + + // If all approximatable clauses are fully traversed, we still have valid scorers + // Don't return null - we 
have valid scorers that contain all the data + return newIterators; } @@ -222,7 +241,7 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr // Check if we need to expand window if (totalCollected < size && (globalConjunction == null || globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS)) { - currentWindowSize *= 3; + currentWindowSize *= scalingFactor; // Rebuild iterators with new window size List newIterators = rebuildIteratorsWithWindowSize(currentWindowSize); @@ -232,11 +251,11 @@ public int score(LeafCollector collector, Bits acceptDocs, int min, int max) thr } globalConjunction = ConjunctionUtils.intersectIterators(newIterators); - // Return first docID from new conjunction (could be < min) + // Return first docID from new conjunction to reset min int firstDoc = globalConjunction.nextDoc(); if (firstDoc != DocIdSetIterator.NO_MORE_DOCS) { return firstDoc; // CancellableBulkScorer will use this as new min - } else {} + } } // Score existing conjunction within [min, max) range @@ -268,38 +287,47 @@ private int scoreExistingConjunction(LeafCollector collector, Bits acceptDocs, i return DocIdSetIterator.NO_MORE_DOCS; // Early termination } - if (acceptDocs == null || acceptDocs.get(doc)) { + // BitSet duplicate detection - only collect if not already collected + if (!collectedDocs.get(doc) && (acceptDocs == null || acceptDocs.get(doc))) { + collectedDocs.set(doc); // Mark as collected collector.collect(doc); collected++; totalCollected++; + } else if (collectedDocs.get(doc)) { } } // Check if conjunction exhausted if (globalConjunction.docID() == DocIdSetIterator.NO_MORE_DOCS) { - // If we need more hits, expand immediately if (totalCollected < size) { - currentWindowSize *= 3; + int oldWindowSize = currentWindowSize; + currentWindowSize *= scalingFactor; try { List newIterators = rebuildIteratorsWithWindowSize(currentWindowSize); if (newIterators == null) { - // A clause is fully traversed, end conjunction + // A clause is 
fully traversed, restore window size and end conjunction + currentWindowSize = oldWindowSize; return DocIdSetIterator.NO_MORE_DOCS; } + + // Expansion succeeded globalConjunction = ConjunctionUtils.intersectIterators(newIterators); + // Start fresh from beginning of new conjunction int firstDoc = globalConjunction.nextDoc(); if (firstDoc != DocIdSetIterator.NO_MORE_DOCS) { - return firstDoc; // Return new starting point + return firstDoc; // Return new starting point to reset min } + return firstDoc; // Return new starting point + } catch (IOException e) {} } - } return globalConjunction.docID(); + } @Override diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 4aef6c0336012..07d51df99b276 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -58,6 +58,8 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { public PointRangeQuery pointRangeQuery; private final Function valueToString; + private boolean hasBeenFullyTraversed = false; + public ApproximatePointRangeQuery( String field, byte[] lowerPoint, @@ -104,6 +106,10 @@ public void setSortOrder(SortOrder sortOrder) { this.sortOrder = sortOrder; } + public boolean getFullyTraversed() { + return hasBeenFullyTraversed; + } + @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); @@ -391,6 +397,7 @@ public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.Poi @Override public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + hasBeenFullyTraversed = false; LeafReader reader = context.reader(); long[] docCount = { 0 }; @@ -398,39 +405,18 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws 
IOExcepti if (checkValidPointValues(values) == false) { return null; } - // values.size(): total points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) - if (size > values.size()) { - ScorerSupplier luceneSupplier = pointRangeQueryWeight.scorerSupplier(context); - return new ScorerSupplier() { - boolean alreadyFullyTraversed = false; - - @Override - public Scorer get(long leadCost) throws IOException { - return getWithSize(size); - } - - public Scorer getWithSize(int dynamicSize) throws IOException { - if (alreadyFullyTraversed) { - return null; // Signal end of conjunction - } - alreadyFullyTraversed = true; - return luceneSupplier.get(Long.MAX_VALUE); - } - - @Override - public long cost() { - return luceneSupplier.cost(); - } - }; + // values.size(): total points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) + if (size > values.size() && context.isTopLevel) { + return pointRangeQueryWeight.scorerSupplier(context); } else { if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { return new ScorerSupplier() { - // Keep a visitor for cost estimation only DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); long cost = -1; + long lastDoc = -1; @Override public Scorer get(long leadCost) throws IOException { @@ -450,7 +436,9 @@ public Scorer getWithSize(int dynamicSize) throws IOException { size = dynamicSize; try { - + if (size > values.size()) { + hasBeenFullyTraversed = true; + } // For windowed approach, create fresh iterator without ResumableDISI state DocIdSetBuilder freshResult = new DocIdSetBuilder(reader.maxDoc(), values); long[] freshDocCount = new long[1]; @@ -460,6 +448,7 @@ public Scorer getWithSize(int dynamicSize) throws IOException { intersectLeft(values.getPointTree(), freshVisitor, freshDocCount); DocIdSetIterator iterator = freshResult.build().iterator(); + 
lastDoc = iterator.docIDRunEnd(); return new ConstantScoreScorer(score(), scoreMode, iterator); } finally { // Restore original size @@ -470,12 +459,21 @@ public Scorer getWithSize(int dynamicSize) throws IOException { @Override public long cost() { if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + if (context.isTopLevel) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } else { + return lastDoc != -1 ? lastDoc : values.estimateDocCount(visitor); + } } return cost; } + + public int getBKDSize() throws IOException { + return reader.getPointValues(pointRangeQuery.getField()).getDocCount(); + } + }; } else { // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java index 07a1b7929896a..89c49ffc853ab 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java @@ -18,11 +18,8 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; -import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.IndexSearcher; -import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; -import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; @@ -42,9 +39,7 @@ import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; 
-import java.util.List; import java.util.concurrent.ExecutorService; import static org.mockito.ArgumentMatchers.any; @@ -232,13 +227,6 @@ public void testApproximateVsExactResults() throws IOException { ApproximatePointRangeQuery.INT_FORMAT ); - // ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery(IntPoint.newRangeQuery("field1", lower1, upper1), new - // ApproximatePointRangeQuery("field1", IntPoint.pack(new int[]{lower1}).bytes, IntPoint.pack(new int[]{upper1}).bytes, - // 1, ApproximatePointRangeQuery.INT_FORMAT)); - // ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery(IntPoint.newRangeQuery("field2", lower2, upper2), new - // ApproximatePointRangeQuery("field2", IntPoint.pack(new int[]{lower2}).bytes, IntPoint.pack(new int[]{upper2}).bytes, - // 1, ApproximatePointRangeQuery.INT_FORMAT)); - BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); @@ -252,6 +240,9 @@ public void testApproximateVsExactResults() throws IOException { TopDocs approximateDocs = searcher.search(approximateQuery, 1000); TopDocs exactDocs = searcher.search(exactQuery, 1000); + System.out.println("Exact docs total hits: " + exactDocs.totalHits.value()); + System.out.println("Approx docs total hits: " + approximateDocs.totalHits.value()); + // Results should be identical when approximation is not triggered // or when we collect all available documents if (exactDocs.totalHits.value() <= 1000) { @@ -266,59 +257,6 @@ public void testApproximateVsExactResults() throws IOException { } } - // Test early termination at 10k hits - public void testEarlyTerminationAt10k() throws IOException { - try (Directory directory = newDirectory()) { - try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { - // Create enough documents to exceed 10k hits - for (int i = 0; i < 20000; i++) { - Document doc = new Document(); - doc.add(new 
IntPoint("field1", i % 100)); // High overlap - doc.add(new IntPoint("field2", i % 50)); // High overlap - iw.addDocument(doc); - } - iw.flush(); - - try (IndexReader reader = iw.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - - // Create query that should match many documents - - ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field1", 0, 99), - new ApproximatePointRangeQuery( - "field1", - IntPoint.pack(new int[] { 0 }).bytes, - IntPoint.pack(new int[] { 99 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) - ); - ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field2", 0, 49), - new ApproximatePointRangeQuery( - "field2", - IntPoint.pack(new int[] { 0 }).bytes, - IntPoint.pack(new int[] { 49 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) - ); - - BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) - .add(approxQuery2, BooleanClause.Occur.FILTER) - .build(); - ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); - - TopDocs docs = searcher.search(query, 15000); - - // Should terminate early at exactly 10k hits - assertEquals("Should collect exactly 10k documents", 10000, docs.totalHits.value()); - } - } - } - } - // Test with single clause (nested ApproximateScoreQuery case) public void testSingleClauseApproximation() { ApproximatePointRangeQuery pointQuery = new ApproximatePointRangeQuery( @@ -729,308 +667,238 @@ public void testNestedQueryWithMustNotClause() { assertFalse("Nested query with MUST_NOT should not be approximatable", outerQuery.canApproximate(mockContext)); } - // Test BulkScorer windowed approach with small dataset - public void testBulkScorerWindowedExpansionSmall() throws IOException { - try (Directory directory = newDirectory()) { - try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { - // Add documents with 
overlapping ranges - for (int i = 0; i < 1000; i++) { - Document doc = new Document(); - doc.add(new IntPoint("field1", i)); - doc.add(new IntPoint("field2", i % 100)); // Create overlapping ranges - iw.addDocument(doc); - } - iw.flush(); - - try (IndexReader reader = iw.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - LeafReaderContext leafContext = reader.leaves().get(0); - - ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field1", 100, 900), - new ApproximatePointRangeQuery( - "field1", - IntPoint.pack(new int[] { 100 }).bytes, - IntPoint.pack(new int[] { 900 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) - ); - ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field2", 10, 90), - new ApproximatePointRangeQuery( - "field2", - IntPoint.pack(new int[] { 10 }).bytes, - IntPoint.pack(new int[] { 90 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) - ); - - BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) - .add(approxQuery2, BooleanClause.Occur.FILTER) - .build(); - ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); - - Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); - ScorerSupplier supplier = weight.scorerSupplier(leafContext); - BulkScorer bulkScorer = supplier.bulkScorer(); - - assertNotNull(bulkScorer); - - // Test bulk scoring with collection - List collectedDocs = new ArrayList<>(); - LeafCollector collector = new LeafCollector() { - @Override - public void setScorer(Scorable scorer) throws IOException {} - - @Override - public void collect(int doc) throws IOException { - collectedDocs.add(doc); - } - }; - - int result = bulkScorer.score(collector, null, 0, Integer.MAX_VALUE); + // Test BulkScorer with large dataset to trigger windowed expansion + // public void testBulkScorerWindowedExpansion() throws IOException { + // try (Directory 
directory = newDirectory()) { + // try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + // int numDocs = 20000; + // for (int i = 0; i < numDocs; i++) { + // Document doc = new Document(); + // doc.add(new IntPoint("field1", i)); + // doc.add(new IntPoint("field2", i % 1000)); // Create dense overlapping ranges + // doc.add(new NumericDocValuesField("field1", i)); + // doc.add(new NumericDocValuesField("field2", i % 1000)); + // doc.add(new StoredField("field1", i)); + // doc.add(new StoredField("field2", i % 1000)); + // iw.addDocument(doc); + // } + // iw.flush(); + // + // try (IndexReader reader = iw.getReader()) { + // ContextIndexSearcher searcher = createContextIndexSearcher(reader); + // + // // Create approximate queries directly + // ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + // "field1", + // IntPoint.pack(new int[] { 1000 }).bytes, + // IntPoint.pack(new int[] { 20000 }).bytes, + // 1, + // ApproximatePointRangeQuery.INT_FORMAT + // ); + // ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + // "field2", + // IntPoint.pack(new int[] { 100 }).bytes, + // IntPoint.pack(new int[] { 900 }).bytes, + // 1, + // ApproximatePointRangeQuery.INT_FORMAT + // ); + // + // BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + // .add(approxQuery2, BooleanClause.Occur.FILTER) + // .build(); + // ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + // + // TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + // searcher.search(query, collector); + // TopDocs docs = collector.topDocs(); + // + // System.out.println("ScoreDocs length: "+docs.scoreDocs.length); + // System.out.println("total hits value" + docs.totalHits.value()); + // // Should collect documents and potentially expand windows + // assertTrue("Should collect some documents", 
docs.scoreDocs.length > 0); + // assertTrue("Should collect up to 10k documents or exhaust", docs.scoreDocs.length <= 10001); + // } + // } + // } + // } + + /** + * Creates a ContextIndexSearcher with properly mocked SearchContext for testing. + */ + private ContextIndexSearcher createContextIndexSearcher(IndexReader reader) throws IOException { + SearchContext searchContext = mock(SearchContext.class); + IndexShard indexShard = mock(IndexShard.class); + when(searchContext.indexShard()).thenReturn(indexShard); + SearchOperationListener searchOperationListener = new SearchOperationListener() { + }; + when(indexShard.getSearchOperationListener()).thenReturn(searchOperationListener); + when(searchContext.bucketCollectorProcessor()).thenReturn(new BucketCollectorProcessor()); + when(searchContext.asLocalBucketCountThresholds(any())).thenCallRealMethod(); + + ContextIndexSearcher searcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + mock(ExecutorService.class), + searchContext + ); - // Should collect documents - assertTrue("Should collect some documents", collectedDocs.size() > 0); - assertTrue("Should collect reasonable number of documents", collectedDocs.size() <= 1000); - } - } - } + searcher.addQueryCancellation(() -> {}); + return searcher; } - // Test BulkScorer with large dataset to trigger windowed expansion - public void testBulkScorerWindowedExpansionLarge() throws IOException { + // // Integration test validating hit count and accuracy + // public void testApproximateResultsValidation() throws IOException { + // try (Directory directory = newDirectory()) { + // try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + // int numDocs = 20000; + // for (int i = 0; i < numDocs; i++) { + // Document doc = new Document(); + // int field1Value = i % 1000; // Values: 0-999 (1000 unique 
values) + // int field2Value = i % 500; // Values: 0-499 (500 unique values) + // doc.add(new IntPoint("field1", field1Value)); + // doc.add(new IntPoint("field2", field2Value)); + // doc.add(new NumericDocValuesField("field1", field1Value)); + // doc.add(new NumericDocValuesField("field2", field2Value)); + // doc.add(new StoredField("field1", field1Value)); + // doc.add(new StoredField("field2", field2Value)); + // iw.addDocument(doc); + // } + // iw.flush(); + // + // try (IndexReader reader = iw.getReader()) { + // ContextIndexSearcher searcher = createContextIndexSearcher(reader); + // + // int lower1 = 100; + // int upper1 = 200; + // int lower2 = 50; + // int upper2 = 150; + // + // // Create approximate query + // ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + // "field1", + // IntPoint.pack(new int[] { lower1 }).bytes, + // IntPoint.pack(new int[] { upper1 }).bytes, + // 1, + // ApproximatePointRangeQuery.INT_FORMAT + // ); + // ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + // "field2", + // IntPoint.pack(new int[] { lower2 }).bytes, + // IntPoint.pack(new int[] { upper2 }).bytes, + // 1, + // ApproximatePointRangeQuery.INT_FORMAT + // ); + // + // BooleanQuery approximateBoolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + // .add(approxQuery2, BooleanClause.Occur.FILTER) + // .build(); + // ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(approximateBoolQuery); + // + // // Create exact query (regular Lucene BooleanQuery) + // BooleanQuery exactBoolQuery = new BooleanQuery.Builder().add( + // IntPoint.newRangeQuery("field1", lower1, upper1), + // BooleanClause.Occur.FILTER + // ).add(IntPoint.newRangeQuery("field2", lower2, upper2), BooleanClause.Occur.FILTER).build(); + // + // TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + // + // searcher.search(approximateQuery, collector); + // + // // 
Search with both queries + // TopDocs approximateDocs = collector.topDocs(); + // + // TopScoreDocCollector collectorExact = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + // + // searcher.search(exactBoolQuery, collectorExact); + // + // // Search with both queries + // TopDocs exactDocs = collectorExact.topDocs(); + // + // System.out.println("Exact hits: " + exactDocs.totalHits.value()); + // System.out.println("Approximate hits: " + approximateDocs.totalHits.value()); + // System.out.println("approximate score docs length: " + approximateDocs.scoreDocs.length); + // // Validate hit count logic + // if (exactDocs.totalHits.value() <= 10000) { + // assertEquals( + // "When exact results ≤ 10k, approximate should match exactly", + // exactDocs.totalHits.value(), + // approximateDocs.totalHits.value() + // ); + // } else { + // assertEquals( + // "Approximate should return exactly 10k hits when exact > 10k", + // 10000, + // approximateDocs.totalHits.value() + // ); + // } + // + // // Validate hit accuracy - each returned doc should match the query criteria + // StoredFields storedFields = reader.storedFields(); + // for (int i = 0; i < approximateDocs.scoreDocs.length; i++) { + // int docId = approximateDocs.scoreDocs[i].doc; + // Document doc = storedFields.document(docId); + // + // int field1Value = doc.getField("field1").numericValue().intValue(); + // int field2Value = doc.getField("field2").numericValue().intValue(); + // + // assertTrue( + // "field1 should be in range [" + lower1 + ", " + upper1 + "], got: " + field1Value, + // field1Value >= lower1 && field1Value <= upper1 + // ); + // assertTrue( + // "field2 should be in range [" + lower2 + ", " + upper2 + "], got: " + field2Value, + // field2Value >= lower2 && field2Value <= upper2 + // ); + // } + // } + // } + // } + // } + + // Test window size heuristic with different cost scenarios + public void testWindowSizeHeuristic() throws IOException { try (Directory directory = 
newDirectory()) { try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { - int numDocs = 20000; - for (int i = 0; i < numDocs; i++) { + for (int i = 0; i < 1000; i++) { Document doc = new Document(); doc.add(new IntPoint("field1", i)); - doc.add(new IntPoint("field2", i % 1000)); // Create dense overlapping ranges + doc.add(new IntPoint("field2", i * 2)); iw.addDocument(doc); } iw.flush(); try (IndexReader reader = iw.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); + ContextIndexSearcher searcher = createContextIndexSearcher(reader); LeafReaderContext leafContext = reader.leaves().get(0); - ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field1", 1000, 20000), - new ApproximatePointRangeQuery( - "field1", - IntPoint.pack(new int[] { 1000 }).bytes, - IntPoint.pack(new int[] { 20000 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) - ); - ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( - IntPoint.newRangeQuery("field2", 100, 900), - new ApproximatePointRangeQuery( - "field2", - IntPoint.pack(new int[] { 100 }).bytes, - IntPoint.pack(new int[] { 900 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) - ); - - BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) - .add(approxQuery2, BooleanClause.Occur.FILTER) - .build(); - ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); - - Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); - ScorerSupplier supplier = weight.scorerSupplier(leafContext); - BulkScorer bulkScorer = supplier.bulkScorer(); - - assertNotNull(bulkScorer); - - // Test bulk scoring with collection - List collectedDocs = new ArrayList<>(); - LeafCollector collector = new LeafCollector() { - @Override - public void setScorer(Scorable scorer) throws IOException {} - - @Override - public void collect(int doc) throws IOException { - 
collectedDocs.add(doc); - } - }; - - int result = bulkScorer.score(collector, null, 0, Integer.MAX_VALUE); - - // Should collect documents and potentially expand windows - assertTrue("Should collect some documents", collectedDocs.size() > 0); - assertTrue("Should collect up to 10k documents or exhaust", collectedDocs.size() <= 10000); - } - } - } - } - - // Integration test validating hit count and accuracy - public void testApproximateResultsValidation() throws IOException { - try (Directory directory = newDirectory()) { - try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { - int numDocs = 20000; - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - int field1Value = i % 10; - int field2Value = (i * 2) % 10; - doc.add(new IntPoint("field1", field1Value)); - doc.add(new IntPoint("field2", field2Value)); - doc.add(new NumericDocValuesField("field1", field1Value)); - doc.add(new NumericDocValuesField("field2", field2Value)); - doc.add(new StoredField("field1", field1Value)); - doc.add(new StoredField("field2", field2Value)); - iw.addDocument(doc); - } - iw.flush(); - - try (IndexReader reader = iw.getReader()) { - - SearchContext searchContext = mock(SearchContext.class); - IndexShard indexShard = mock(IndexShard.class); - when(searchContext.indexShard()).thenReturn(indexShard); - SearchOperationListener searchOperationListener = new SearchOperationListener() { - }; - when(indexShard.getSearchOperationListener()).thenReturn(searchOperationListener); - when(searchContext.bucketCollectorProcessor()).thenReturn(new BucketCollectorProcessor()); - when(searchContext.asLocalBucketCountThresholds(any())).thenCallRealMethod(); - - // ContextIndexSearcher searcher = mock(ContextIndexSearcher.class); - ContextIndexSearcher searcher = new ContextIndexSearcher( - reader, - IndexSearcher.getDefaultSimilarity(), - IndexSearcher.getDefaultQueryCache(), - IndexSearcher.getDefaultQueryCachingPolicy(), - true, - 
mock(ExecutorService.class), - searchContext - ); - - searcher.addQueryCancellation(() -> {}); - - int lower1 = 2; - int upper1 = 5; - int lower2 = 4; - int upper2 = 5; - - // Create approximate query + // Create approximate queries directly ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( "field1", - IntPoint.pack(new int[] { lower1 }).bytes, - IntPoint.pack(new int[] { upper1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + IntPoint.pack(new int[] { 900 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ); ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( "field2", - IntPoint.pack(new int[] { lower2 }).bytes, - IntPoint.pack(new int[] { upper2 }).bytes, + IntPoint.pack(new int[] { 200 }).bytes, + IntPoint.pack(new int[] { 1800 }).bytes, 1, ApproximatePointRangeQuery.INT_FORMAT ); - BooleanQuery approximateBoolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); - ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(approximateBoolQuery); - - // Create exact query (regular Lucene BooleanQuery) - BooleanQuery exactBoolQuery = new BooleanQuery.Builder().add( - IntPoint.newRangeQuery("field1", lower1, upper1), - BooleanClause.Occur.FILTER - ).add(IntPoint.newRangeQuery("field2", lower2, upper2), BooleanClause.Occur.FILTER).build(); - - TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); - - searcher.search(approximateQuery, collector); - - // Search with both queries - TopDocs approximateDocs = collector.topDocs(); - - TopScoreDocCollector collectorExact = new TopScoreDocCollectorManager(10001, 10001).newCollector(); - - searcher.search(exactBoolQuery, collectorExact); - - // Search with both queries - TopDocs exactDocs = collectorExact.topDocs(); - - 
System.out.println("Exact hits: " + exactDocs.totalHits.value()); - System.out.println("Approximate hits: " + approximateDocs.totalHits.value()); - System.out.println("approximate score docs length: " + approximateDocs.scoreDocs.length); - // Validate hit count logic - if (exactDocs.totalHits.value() <= 10000) { - assertEquals( - "When exact results ≤ 10k, approximate should match exactly", - exactDocs.totalHits.value(), - approximateDocs.totalHits.value() - ); - } else { - assertEquals( - "Approximate should return exactly 10k hits when exact > 10k", - 10000, - approximateDocs.totalHits.value() - ); - } - - // Validate hit accuracy - each returned doc should match the query criteria - StoredFields storedFields = reader.storedFields(); - for (int i = 0; i < approximateDocs.scoreDocs.length; i++) { - int docId = approximateDocs.scoreDocs[i].doc; - Document doc = storedFields.document(docId); - - int field1Value = doc.getField("field1").numericValue().intValue(); - int field2Value = doc.getField("field2").numericValue().intValue(); - - assertTrue( - "field1 should be in range [" + lower1 + ", " + upper1 + "], got: " + field1Value, - field1Value >= lower1 && field1Value <= upper1 - ); - assertTrue( - "field2 should be in range [" + lower2 + ", " + upper2 + "], got: " + field2Value, - field2Value >= lower2 && field2Value <= upper2 - ); - } - } - } - } - } - - // Test window size heuristic with different cost scenarios - public void testWindowSizeHeuristic() throws IOException { - try (Directory directory = newDirectory()) { - try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { - for (int i = 0; i < 1000; i++) { - Document doc = new Document(); - doc.add(new IntPoint("field1", i)); - doc.add(new IntPoint("field2", i * 2)); - iw.addDocument(doc); - } - iw.flush(); - - try (IndexReader reader = iw.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); - LeafReaderContext leafContext = reader.leaves().get(0); - 
- BooleanQuery boolQuery = new BooleanQuery.Builder().add( - IntPoint.newRangeQuery("field1", 100, 900), - BooleanClause.Occur.FILTER - ).add(IntPoint.newRangeQuery("field2", 200, 1800), BooleanClause.Occur.FILTER).build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); ApproximateBooleanScorerSupplier supplier = (ApproximateBooleanScorerSupplier) weight.scorerSupplier(leafContext); assertNotNull(supplier); - - // Test that cost calculation works - long cost = supplier.cost(); - assertTrue("Cost should be positive", cost > 0); } } } @@ -1055,12 +923,14 @@ public void testSparseDataDistribution() throws IOException { doc.add(new IntPoint(fieldName2, statusCode)); doc.add(new NumericDocValuesField(fieldName1, timestamp)); doc.add(new NumericDocValuesField(fieldName2, statusCode)); + doc.add(new StoredField(fieldName1, timestamp)); + doc.add(new StoredField(fieldName2, statusCode)); iw.addDocument(doc); } iw.flush(); try (IndexReader reader = iw.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); + ContextIndexSearcher searcher = createContextIndexSearcher(reader); // Test query for specific time range and status codes testApproximateQueryValidation(searcher, fieldName1, fieldName2, 10000, 50000, 200, 500, 100); @@ -1088,6 +958,8 @@ public void testDenseDataDistribution() throws IOException { doc.add(new IntPoint(fieldName2, distance)); doc.add(new NumericDocValuesField(fieldName1, fare)); doc.add(new NumericDocValuesField(fieldName2, distance)); + doc.add(new StoredField(fieldName1, fare)); + doc.add(new StoredField(fieldName2, distance)); iw.addDocument(doc); } } @@ -1095,7 +967,7 @@ public void testDenseDataDistribution() throws IOException { iw.flush(); try (IndexReader reader = iw.getReader()) { - IndexSearcher searcher = new IndexSearcher(reader); + ContextIndexSearcher searcher = createContextIndexSearcher(reader); // Test queries for different fare 
and distance ranges testApproximateQueryValidation(searcher, fieldName1, fieldName2, 1000, 3000, 5, 25, 200); @@ -1106,7 +978,7 @@ public void testDenseDataDistribution() throws IOException { } public void testApproximateQueryValidation( - IndexSearcher searcher, + ContextIndexSearcher searcher, String field1, String field2, int lower1, @@ -1115,34 +987,30 @@ public void testApproximateQueryValidation( int upper2, int size ) throws IOException { - // Test with approximate query - ApproximateScoreQuery approxQuery1 = new ApproximateScoreQuery( - IntPoint.newRangeQuery(field1, lower1, upper1), - new ApproximatePointRangeQuery( - field1, - IntPoint.pack(new int[] { lower1 }).bytes, - IntPoint.pack(new int[] { upper1 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) + // Create approximate query using ApproximatePointRangeQuery directly + ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + field1, + IntPoint.pack(new int[] { lower1 }).bytes, + IntPoint.pack(new int[] { upper1 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT ); - ApproximateScoreQuery approxQuery2 = new ApproximateScoreQuery( - IntPoint.newRangeQuery(field2, lower2, upper2), - new ApproximatePointRangeQuery( - field2, - IntPoint.pack(new int[] { lower2 }).bytes, - IntPoint.pack(new int[] { upper2 }).bytes, - 1, - ApproximatePointRangeQuery.INT_FORMAT - ) + ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + field2, + IntPoint.pack(new int[] { lower2 }).bytes, + IntPoint.pack(new int[] { upper2 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT ); BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); - ApproximateScoreQuery approxQuery = new ApproximateScoreQuery(boolQuery, new ApproximateBooleanQuery(boolQuery)); + ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(boolQuery); - TopDocs approxDocs = 
searcher.search(approxQuery, size); + TopScoreDocCollector collector = new TopScoreDocCollectorManager(size + 1, size + 1).newCollector(); + searcher.search(approximateQuery, collector); + TopDocs approxDocs = collector.topDocs(); // Validate hit count assertTrue("Approximate query should return at most " + size + " docs", approxDocs.scoreDocs.length <= size); From 511603fc343a5e8e38c7f3917b2298b2972548fb Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 22 Aug 2025 05:10:24 +0000 Subject: [PATCH 34/38] added more tests Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanQueryTests.java | 320 +++++++++--------- 1 file changed, 156 insertions(+), 164 deletions(-) diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java index 89c49ffc853ab..fa6515ba7729b 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java @@ -240,9 +240,6 @@ public void testApproximateVsExactResults() throws IOException { TopDocs approximateDocs = searcher.search(approximateQuery, 1000); TopDocs exactDocs = searcher.search(exactQuery, 1000); - System.out.println("Exact docs total hits: " + exactDocs.totalHits.value()); - System.out.println("Approx docs total hits: " + approximateDocs.totalHits.value()); - // Results should be identical when approximation is not triggered // or when we collect all available documents if (exactDocs.totalHits.value() <= 1000) { @@ -668,59 +665,57 @@ public void testNestedQueryWithMustNotClause() { } // Test BulkScorer with large dataset to trigger windowed expansion - // public void testBulkScorerWindowedExpansion() throws IOException { - // try (Directory directory = newDirectory()) { - // try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new 
WhitespaceAnalyzer())) { - // int numDocs = 20000; - // for (int i = 0; i < numDocs; i++) { - // Document doc = new Document(); - // doc.add(new IntPoint("field1", i)); - // doc.add(new IntPoint("field2", i % 1000)); // Create dense overlapping ranges - // doc.add(new NumericDocValuesField("field1", i)); - // doc.add(new NumericDocValuesField("field2", i % 1000)); - // doc.add(new StoredField("field1", i)); - // doc.add(new StoredField("field2", i % 1000)); - // iw.addDocument(doc); - // } - // iw.flush(); - // - // try (IndexReader reader = iw.getReader()) { - // ContextIndexSearcher searcher = createContextIndexSearcher(reader); - // - // // Create approximate queries directly - // ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( - // "field1", - // IntPoint.pack(new int[] { 1000 }).bytes, - // IntPoint.pack(new int[] { 20000 }).bytes, - // 1, - // ApproximatePointRangeQuery.INT_FORMAT - // ); - // ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( - // "field2", - // IntPoint.pack(new int[] { 100 }).bytes, - // IntPoint.pack(new int[] { 900 }).bytes, - // 1, - // ApproximatePointRangeQuery.INT_FORMAT - // ); - // - // BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) - // .add(approxQuery2, BooleanClause.Occur.FILTER) - // .build(); - // ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); - // - // TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); - // searcher.search(query, collector); - // TopDocs docs = collector.topDocs(); - // - // System.out.println("ScoreDocs length: "+docs.scoreDocs.length); - // System.out.println("total hits value" + docs.totalHits.value()); - // // Should collect documents and potentially expand windows - // assertTrue("Should collect some documents", docs.scoreDocs.length > 0); - // assertTrue("Should collect up to 10k documents or exhaust", docs.scoreDocs.length <= 10001); - 
// } - // } - // } - // } + public void testBulkScorerWindowedExpansion() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int numDocs = 20000; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(new IntPoint("field1", i)); + doc.add(new IntPoint("field2", i % 1000)); // Create dense overlapping ranges + doc.add(new NumericDocValuesField("field1", i)); + doc.add(new NumericDocValuesField("field2", i % 1000)); + doc.add(new StoredField("field1", i)); + doc.add(new StoredField("field2", i % 1000)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + ContextIndexSearcher searcher = createContextIndexSearcher(reader); + + // Create approximate queries directly + ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { 1000 }).bytes, + IntPoint.pack(new int[] { 20000 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { 100 }).bytes, + IntPoint.pack(new int[] { 900 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); + + TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + searcher.search(query, collector); + TopDocs docs = collector.topDocs(); + + // Should collect documents and potentially expand windows + assertTrue("Should collect some documents", docs.scoreDocs.length > 0); + assertTrue("Should collect up to 10k documents or exhaust", docs.scoreDocs.length <= 10001); + } + } + } + } /** * Creates a ContextIndexSearcher 
with properly mocked SearchContext for testing. @@ -750,113 +745,110 @@ private ContextIndexSearcher createContextIndexSearcher(IndexReader reader) thro } // // Integration test validating hit count and accuracy - // public void testApproximateResultsValidation() throws IOException { - // try (Directory directory = newDirectory()) { - // try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { - // int numDocs = 20000; - // for (int i = 0; i < numDocs; i++) { - // Document doc = new Document(); - // int field1Value = i % 1000; // Values: 0-999 (1000 unique values) - // int field2Value = i % 500; // Values: 0-499 (500 unique values) - // doc.add(new IntPoint("field1", field1Value)); - // doc.add(new IntPoint("field2", field2Value)); - // doc.add(new NumericDocValuesField("field1", field1Value)); - // doc.add(new NumericDocValuesField("field2", field2Value)); - // doc.add(new StoredField("field1", field1Value)); - // doc.add(new StoredField("field2", field2Value)); - // iw.addDocument(doc); - // } - // iw.flush(); - // - // try (IndexReader reader = iw.getReader()) { - // ContextIndexSearcher searcher = createContextIndexSearcher(reader); - // - // int lower1 = 100; - // int upper1 = 200; - // int lower2 = 50; - // int upper2 = 150; - // - // // Create approximate query - // ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( - // "field1", - // IntPoint.pack(new int[] { lower1 }).bytes, - // IntPoint.pack(new int[] { upper1 }).bytes, - // 1, - // ApproximatePointRangeQuery.INT_FORMAT - // ); - // ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( - // "field2", - // IntPoint.pack(new int[] { lower2 }).bytes, - // IntPoint.pack(new int[] { upper2 }).bytes, - // 1, - // ApproximatePointRangeQuery.INT_FORMAT - // ); - // - // BooleanQuery approximateBoolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) - // .add(approxQuery2, 
BooleanClause.Occur.FILTER) - // .build(); - // ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(approximateBoolQuery); - // - // // Create exact query (regular Lucene BooleanQuery) - // BooleanQuery exactBoolQuery = new BooleanQuery.Builder().add( - // IntPoint.newRangeQuery("field1", lower1, upper1), - // BooleanClause.Occur.FILTER - // ).add(IntPoint.newRangeQuery("field2", lower2, upper2), BooleanClause.Occur.FILTER).build(); - // - // TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); - // - // searcher.search(approximateQuery, collector); - // - // // Search with both queries - // TopDocs approximateDocs = collector.topDocs(); - // - // TopScoreDocCollector collectorExact = new TopScoreDocCollectorManager(10001, 10001).newCollector(); - // - // searcher.search(exactBoolQuery, collectorExact); - // - // // Search with both queries - // TopDocs exactDocs = collectorExact.topDocs(); - // - // System.out.println("Exact hits: " + exactDocs.totalHits.value()); - // System.out.println("Approximate hits: " + approximateDocs.totalHits.value()); - // System.out.println("approximate score docs length: " + approximateDocs.scoreDocs.length); - // // Validate hit count logic - // if (exactDocs.totalHits.value() <= 10000) { - // assertEquals( - // "When exact results ≤ 10k, approximate should match exactly", - // exactDocs.totalHits.value(), - // approximateDocs.totalHits.value() - // ); - // } else { - // assertEquals( - // "Approximate should return exactly 10k hits when exact > 10k", - // 10000, - // approximateDocs.totalHits.value() - // ); - // } - // - // // Validate hit accuracy - each returned doc should match the query criteria - // StoredFields storedFields = reader.storedFields(); - // for (int i = 0; i < approximateDocs.scoreDocs.length; i++) { - // int docId = approximateDocs.scoreDocs[i].doc; - // Document doc = storedFields.document(docId); - // - // int field1Value = 
doc.getField("field1").numericValue().intValue(); - // int field2Value = doc.getField("field2").numericValue().intValue(); - // - // assertTrue( - // "field1 should be in range [" + lower1 + ", " + upper1 + "], got: " + field1Value, - // field1Value >= lower1 && field1Value <= upper1 - // ); - // assertTrue( - // "field2 should be in range [" + lower2 + ", " + upper2 + "], got: " + field2Value, - // field2Value >= lower2 && field2Value <= upper2 - // ); - // } - // } - // } - // } - // } + public void testApproximateResultsValidation() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int numDocs = 20000; + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + int field1Value = i % 1000; // Values: 0-999 (1000 unique values) + int field2Value = i % 500; // Values: 0-499 (500 unique values) + doc.add(new IntPoint("field1", field1Value)); + doc.add(new IntPoint("field2", field2Value)); + doc.add(new NumericDocValuesField("field1", field1Value)); + doc.add(new NumericDocValuesField("field2", field2Value)); + doc.add(new StoredField("field1", field1Value)); + doc.add(new StoredField("field2", field2Value)); + iw.addDocument(doc); + } + iw.flush(); + + try (IndexReader reader = iw.getReader()) { + ContextIndexSearcher searcher = createContextIndexSearcher(reader); + + int lower1 = 100; + int upper1 = 200; + int lower2 = 50; + int upper2 = 150; + + // Create approximate query + ApproximatePointRangeQuery approxQuery1 = new ApproximatePointRangeQuery( + "field1", + IntPoint.pack(new int[] { lower1 }).bytes, + IntPoint.pack(new int[] { upper1 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( + "field2", + IntPoint.pack(new int[] { lower2 }).bytes, + IntPoint.pack(new int[] { upper2 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ); + + 
BooleanQuery approximateBoolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) + .add(approxQuery2, BooleanClause.Occur.FILTER) + .build(); + ApproximateBooleanQuery approximateQuery = new ApproximateBooleanQuery(approximateBoolQuery); + + // Create exact query (regular Lucene BooleanQuery) + BooleanQuery exactBoolQuery = new BooleanQuery.Builder().add( + IntPoint.newRangeQuery("field1", lower1, upper1), + BooleanClause.Occur.FILTER + ).add(IntPoint.newRangeQuery("field2", lower2, upper2), BooleanClause.Occur.FILTER).build(); + + TopScoreDocCollector collector = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + + searcher.search(approximateQuery, collector); + + // Search with both queries + TopDocs approximateDocs = collector.topDocs(); + + TopScoreDocCollector collectorExact = new TopScoreDocCollectorManager(10001, 10001).newCollector(); + + searcher.search(exactBoolQuery, collectorExact); + + // Search with both queries + TopDocs exactDocs = collectorExact.topDocs(); + + // Validate hit count logic + if (exactDocs.totalHits.value() <= 10000) { + assertEquals( + "When exact results ≤ 10k, approximate should match exactly", + exactDocs.totalHits.value(), + approximateDocs.totalHits.value() + ); + } else { + assertEquals( + "Approximate should return exactly 10k hits when exact > 10k", + 10000, + approximateDocs.totalHits.value() + ); + } + + // Validate hit accuracy - each returned doc should match the query criteria + StoredFields storedFields = reader.storedFields(); + for (int i = 0; i < approximateDocs.scoreDocs.length; i++) { + int docId = approximateDocs.scoreDocs[i].doc; + Document doc = storedFields.document(docId); + + int field1Value = doc.getField("field1").numericValue().intValue(); + int field2Value = doc.getField("field2").numericValue().intValue(); + + assertTrue( + "field1 should be in range [" + lower1 + ", " + upper1 + "], got: " + field1Value, + field1Value >= lower1 && field1Value <= upper1 + ); + 
assertTrue( + "field2 should be in range [" + lower2 + ", " + upper2 + "], got: " + field2Value, + field2Value >= lower2 && field2Value <= upper2 + ); + } + } + } + } + } // Test window size heuristic with different cost scenarios public void testWindowSizeHeuristic() throws IOException { @@ -1013,7 +1005,7 @@ public void testApproximateQueryValidation( TopDocs approxDocs = collector.topDocs(); // Validate hit count - assertTrue("Approximate query should return at most " + size + " docs", approxDocs.scoreDocs.length <= size); + assertTrue("Approximate query should return at most " + size + 1 + " docs", approxDocs.scoreDocs.length <= size + 1); assertTrue("Should not exceed 10k hits", approxDocs.totalHits.value() <= 10000); // Validate hit accuracy - each returned doc should match the query criteria From 75facd39adf6af03386b8a475d1ce2a6e7f05e2a Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 22 Aug 2025 05:29:51 +0000 Subject: [PATCH 35/38] corrected canApproximate tests Signed-off-by: Sawan Srivastava --- .../ApproximateBooleanQueryTests.java | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java index fa6515ba7729b..f458e3d7ddc6b 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java @@ -333,7 +333,18 @@ public void testAllFilterClausesCanApproximate() { } public void testSingleClauseMustCanApproximate() { - BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.MUST).build(); + ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field", 1, 100), + new ApproximatePointRangeQuery( + "field", + 
IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.MUST).build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); SearchContext mockContext = mock(SearchContext.class); @@ -341,8 +352,10 @@ public void testSingleClauseMustCanApproximate() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); + approxQuery.setContext(mockContext); + // Single clause with MUST should return false (not handled by current logic) - assertFalse("Single MUST clause should not be approximatable", query.canApproximate(mockContext)); + assertTrue("Single MUST clause should be approximatable", query.canApproximate(mockContext)); } public void testSingleClauseShouldCanApproximate() { @@ -366,8 +379,17 @@ public void testSingleClauseShouldCanApproximate() { } public void testSingleClauseFilterCanApproximate() { - BooleanQuery boolQuery = new BooleanQuery.Builder().add(IntPoint.newRangeQuery("field", 1, 100), BooleanClause.Occur.FILTER) - .build(); + ApproximateScoreQuery approxQuery = new ApproximateScoreQuery( + IntPoint.newRangeQuery("field", 1, 100), + new ApproximatePointRangeQuery( + "field", + IntPoint.pack(new int[] { 1 }).bytes, + IntPoint.pack(new int[] { 100 }).bytes, + 1, + ApproximatePointRangeQuery.INT_FORMAT + ) + ); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.FILTER).build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); SearchContext mockContext = mock(SearchContext.class); @@ -375,8 +397,10 @@ public void testSingleClauseFilterCanApproximate() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); - // Single clause with FILTER should return false (not MUST_NOT, but not handled) - assertFalse("Single FILTER clause 
should not be approximatable", query.canApproximate(mockContext)); + approxQuery.setContext(mockContext); + + // Single clause with FILTER should approximate + assertTrue("Single FILTER clause should be approximatable", query.canApproximate(mockContext)); } // Test BoolQueryBuilder pattern: Single clause WITH ApproximateScoreQuery wrapper @@ -518,7 +542,6 @@ public void testNestedMultiClauseWithApproximateScoreQuery() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); - // Should delegate to nested ApproximateBooleanQuery and return true assertFalse("Nested multi-FILTER clause should not be approximatable", outerQuery.canApproximate(mockContext)); } From 90a4126be4c10bc504a3621f572ee06a09461999 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 22 Aug 2025 14:51:05 +0000 Subject: [PATCH 36/38] added proper isTopLevel + desc sort Signed-off-by: Sawan Srivastava --- .../approximate/ApproximateBooleanQuery.java | 13 ++ .../ApproximateBooleanScorerSupplier.java | 38 +---- .../ApproximatePointRangeQuery.java | 157 +++++++----------- .../approximate/ApproximateScoreQuery.java | 13 +- .../ApproximateBooleanQueryTests.java | 35 +++- 5 files changed, 118 insertions(+), 138 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java index 4b25e23c13778..0b61bfef8d430 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanQuery.java @@ -33,6 +33,7 @@ public class ApproximateBooleanQuery extends ApproximateQuery { public final BooleanQuery boolQuery; private final int size; private final List clauses; + private boolean isTopLevel = true; // Default to true, set to false when nested in boolean query public ApproximateBooleanQuery(BooleanQuery boolQuery) { 
this(boolQuery, SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO); @@ -48,6 +49,14 @@ public BooleanQuery getBooleanQuery() { return boolQuery; } + public boolean isTopLevel() { + return this.isTopLevel; + } + + public void setTopLevel(boolean isTopLevel) { + this.isTopLevel = isTopLevel; + } + @Override public Query rewrite(IndexSearcher indexSearcher) throws IOException { return super.rewrite(indexSearcher); @@ -89,6 +98,10 @@ protected boolean canApproximate(SearchContext context) { return false; } + if (!isTopLevel) { + return false; + } + // For single clause boolean queries, check if the clause can be approximated if (clauses.size() == 1 && clauses.get(0).occur() != BooleanClause.Occur.MUST_NOT) { // If the clause is already an ApproximateScoreQuery, we can approximate + set context diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java index 501f530eca456..31cb9a400f4aa 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateBooleanScorerSupplier.java @@ -82,42 +82,8 @@ public void setScalingWindowFactor(int factor) { */ @Override public Scorer get(long leadCost) throws IOException { - if (clauseWeights.isEmpty()) { - return null; - } - - // Create appropriate iterators for each clause - List clauseIterators = new ArrayList<>(clauseWeights.size()); - for (int i = 0; i < clauseWeights.size(); i++) { - // Use regular DocIdSetIterator for non-approximatable queries - clauseIterators.add(cachedSuppliers.get(i).get(leadCost).iterator()); - } - - // Use Lucene's ConjunctionUtils to create the conjunction - DocIdSetIterator conjunctionDISI = ConjunctionUtils.intersectIterators(clauseIterators); - - // Create a simple scorer that wraps the conjunction - return new Scorer() { - @Override - public 
DocIdSetIterator iterator() { - return conjunctionDISI; - } - - @Override - public float score() throws IOException { - return 0.0f; - } - - @Override - public float getMaxScore(int upTo) throws IOException { - return 0.0f; - } - - @Override - public int docID() { - return conjunctionDISI.docID(); - } - }; + // should not get called in a non-top level query + return null; } /** diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index 07d51df99b276..2755c35a9d4b2 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -57,6 +57,7 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { private SortOrder sortOrder; public PointRangeQuery pointRangeQuery; private final Function valueToString; + private boolean isTopLevel = true; // Default to true, set to false when nested in boolean query private boolean hasBeenFullyTraversed = false; @@ -98,6 +99,14 @@ public void setSize(int size) { this.size = size; } + public boolean isTopLevel() { + return this.isTopLevel; + } + + public void setTopLevel(boolean isTopLevel) { + this.isTopLevel = isTopLevel; + } + public SortOrder getSortOrder() { return this.sortOrder; } @@ -305,51 +314,6 @@ public void intersectLeft(PointValues.IntersectVisitor visitor, PointValues.Poin pointTree.moveToParent(); } - public void intersectLeftIterative( - PointValues.IntersectVisitor visitor, - PointValues.PointTree pointTree, - long[] docCount, - BKDState state - ) throws IOException { - - while (true) { - PointValues.Relation compare = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - if (compare == PointValues.Relation.CELL_INSIDE_QUERY) { - // Check if processing this entire subtree would exceed our limit - long subtreeSize = 
pointTree.size(); - if (docCount[0] + subtreeSize > size) { - // Too big - need to process children individually - if (pointTree.moveToChild()) { - continue; // Process children one by one - } - } - // Safe to process entire subtree - pointTree.visitDocIDs(visitor); - } else if (compare == PointValues.Relation.CELL_CROSSES_QUERY) { - // The cell crosses the shape boundary, or the cell fully contains the query, so we fall - // through and do full filtering: - if (pointTree.moveToChild()) { - continue; - } - // Leaf node; scan and filter all points in this block: - pointTree.visitDocValues(visitor); - } - // position ourself to next place - while (pointTree.moveToSibling() == false) { - if (pointTree.moveToParent() == false) { - // Reached true root - entire BKD tree traversal is complete - state.setExhausted(true); - return; - } - } - // if (docCount[0] >= size) { - // state.setPointTree(pointTree); - // state.setInProgress(true); - // return; - // } - } - } - // custom intersect visitor to walk the right of tree (from rightmost leaf going left) public void intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, long[] docCount) throws IOException { @@ -407,7 +371,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } // values.size(): total points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) - if (size > values.size() && context.isTopLevel) { + if (size > values.size() && isTopLevel) { return pointRangeQueryWeight.scorerSupplier(context); } else { if (sortOrder == null || sortOrder.equals(SortOrder.ASC)) { @@ -420,7 +384,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti @Override public Scorer get(long leadCost) throws IOException { - if (!context.isTopLevel) { + if (!isTopLevel) { // Use leadCost as dynamic size if it's reasonable, otherwise use original size int dynamicSize = (leadCost > 0 && leadCost < 
Integer.MAX_VALUE) ? (int) leadCost : size; return getWithSize(dynamicSize); @@ -459,7 +423,7 @@ public Scorer getWithSize(int dynamicSize) throws IOException { @Override public long cost() { if (cost == -1) { - if (context.isTopLevel) { + if (isTopLevel) { // Computing the cost may be expensive, so only do it if necessary cost = values.estimateDocCount(visitor); assert cost >= 0; @@ -469,36 +433,69 @@ public long cost() { } return cost; } - - public int getBKDSize() throws IOException { - return reader.getPointValues(pointRangeQuery.getField()).getDocCount(); - } - }; } else { + // Descending sort - use intersectRight // we need to fetch size + deleted docs since the collector will prune away deleted docs resulting in fewer results // than expected final int deletedDocs = reader.numDeletedDocs(); - size += deletedDocs; - return new ScorerSupplier() { + int adjustedSize = size + deletedDocs; - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); + return new ScorerSupplier() { + // Keep a visitor for cost estimation only + DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + PointValues.IntersectVisitor visitor = getIntersectVisitor(result, docCount); long cost = -1; + long lastDoc = -1; @Override public Scorer get(long leadCost) throws IOException { - intersectRight(values.getPointTree(), visitor, docCount); - DocIdSetIterator iterator = result.build().iterator(); - return new ConstantScoreScorer(score(), scoreMode, iterator); + if (!isTopLevel) { + // Use leadCost as dynamic size if it's reasonable, otherwise use original size + int dynamicSize = (leadCost > 0 && leadCost < Integer.MAX_VALUE) ? 
(int) leadCost : adjustedSize; + return getWithSize(dynamicSize); + } else { + // For top-level queries, use standard approach + return getWithSize(adjustedSize); + } + } + + public Scorer getWithSize(int dynamicSize) throws IOException { + // Temporarily update size for this call + int originalSize = size; + size = dynamicSize; + + try { + if (size > values.size()) { + hasBeenFullyTraversed = true; + } + // For windowed approach, create fresh iterator without ResumableDISI state + DocIdSetBuilder freshResult = new DocIdSetBuilder(reader.maxDoc(), values); + long[] freshDocCount = new long[1]; + PointValues.IntersectVisitor freshVisitor = getIntersectVisitor(freshResult, freshDocCount); + + // Always start fresh traversal from root using intersectRight for descending + intersectRight(values.getPointTree(), freshVisitor, freshDocCount); + + DocIdSetIterator iterator = freshResult.build().iterator(); + lastDoc = iterator.docIDRunEnd(); + return new ConstantScoreScorer(score(), scoreMode, iterator); + } finally { + // Restore original size + size = originalSize; + } } @Override public long cost() { if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; + if (isTopLevel) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; + } else { + return lastDoc != -1 ? lastDoc : values.estimateDocCount(visitor); + } } return cost; } @@ -614,38 +611,4 @@ public final String toString(String field) { return sb.toString(); } - - /** - * Class to track the state of BKD tree traversal. 
- */ - public static class BKDState { - private PointValues.PointTree currentTree; - private boolean isExhausted = false; - private boolean inProgress = false; - - public PointValues.PointTree getPointTree() { - return currentTree; - } - - public void setPointTree(PointValues.PointTree tree) { - this.currentTree = tree; - } - - public boolean isExhausted() { - return this.isExhausted; - } - - public void setExhausted(boolean exhausted) { - this.isExhausted = exhausted; - } - - public boolean isInProgress() { - return this.inProgress; - } - - public void setInProgress(boolean inProgress) { - this.inProgress = inProgress; - } - - } } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index 1a8e23a035f27..5a84332e6e08a 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -59,16 +59,23 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { public void setContext(SearchContext context) { resolvedQuery = approximationQuery.canApproximate(context) ? 
approximationQuery : originalQuery; + + if ((resolvedQuery instanceof BooleanQuery) || (resolvedQuery instanceof ApproximateBooleanQuery)) { + resolvedQuery = ApproximateBooleanQuery.boolRewrite(resolvedQuery, context.searcher()); + } + if (resolvedQuery instanceof ApproximateBooleanQuery appxBool) { for (BooleanClause boolClause : appxBool.boolQuery.clauses()) { if (boolClause.query() instanceof ApproximateScoreQuery apprxQuery) { + if (apprxQuery.resolvedQuery instanceof ApproximateBooleanQuery boolQuery) { + boolQuery.setTopLevel(false); + } else if (apprxQuery.resolvedQuery instanceof ApproximatePointRangeQuery pointQuery) { + pointQuery.setTopLevel(false); + } apprxQuery.setContext(context); } } } - if ((resolvedQuery instanceof BooleanQuery) || (resolvedQuery instanceof ApproximateBooleanQuery)) { - resolvedQuery = ApproximateBooleanQuery.boolRewrite(resolvedQuery, context.searcher()); - } } @Override diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java index f458e3d7ddc6b..58e05edbcd177 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximateBooleanQueryTests.java @@ -118,6 +118,10 @@ public void testCanApproximateWithValidFilterClauses() { ) ); + // Set isTopLevel to false since these are nested in boolean query + ((ApproximatePointRangeQuery) approxQuery1.getApproximationQuery()).setTopLevel(false); + ((ApproximatePointRangeQuery) approxQuery2.getApproximationQuery()).setTopLevel(false); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); @@ -181,9 +185,9 @@ public void testScorerSupplierCreation() throws IOException { // Test cost estimation assertTrue(supplier.cost() > 0); - // Test scorer creation 
+ // Test scorer creation, scorer should be null since nested ApproximateBooleanQueries shouldn't exist Scorer scorer = supplier.get(1000); - assertNotNull(scorer); + assertNull(scorer); } } } @@ -219,6 +223,7 @@ public void testApproximateVsExactResults() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery1.setTopLevel(false); ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( "field2", IntPoint.pack(new int[] { lower2 }).bytes, @@ -227,6 +232,8 @@ public void testApproximateVsExactResults() throws IOException { ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery2.setTopLevel(false); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) .build(); @@ -263,6 +270,8 @@ public void testSingleClauseApproximation() { 1, ApproximatePointRangeQuery.LONG_FORMAT ); + pointQuery.setTopLevel(false); // Set as non-top-level since it's nested + ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("field", 1, 100), pointQuery); BooleanQuery boolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.MUST).build(); @@ -314,6 +323,11 @@ public void testAllFilterClausesCanApproximate() { ) ); + // Set isTopLevel to false since these are nested in boolean query + ((ApproximatePointRangeQuery) approxQuery1.getApproximationQuery()).setTopLevel(false); + ((ApproximatePointRangeQuery) approxQuery2.getApproximationQuery()).setTopLevel(false); + ((ApproximatePointRangeQuery) approxQuery3.getApproximationQuery()).setTopLevel(false); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) .add(approxQuery3, BooleanClause.Occur.FILTER) @@ -344,6 +358,9 @@ public void testSingleClauseMustCanApproximate() { ) ); + // Set isTopLevel to false since it's nested in boolean query + ((ApproximatePointRangeQuery) 
approxQuery.getApproximationQuery()).setTopLevel(false); + BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery, BooleanClause.Occur.MUST).build(); ApproximateBooleanQuery query = new ApproximateBooleanQuery(boolQuery); @@ -413,6 +430,7 @@ public void testSingleClauseWithApproximateScoreQueryCanApproximate() { 1, ApproximatePointRangeQuery.LONG_FORMAT ); + approxQuery.setTopLevel(false); ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("field", 1, 100), approxQuery); // Test all single clause types (MUST, SHOULD, FILTER) - all should work @@ -442,6 +460,7 @@ public void testSingleClauseMustNotCannotApproximate() { 1, ApproximatePointRangeQuery.LONG_FORMAT ); + approxQuery.setTopLevel(false); ApproximateScoreQuery scoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("field", 1, 100), approxQuery); BooleanQuery boolQuery = new BooleanQuery.Builder().add(scoreQuery, BooleanClause.Occur.MUST_NOT).build(); @@ -465,6 +484,7 @@ public void testNestedSingleClauseWithApproximateScoreQuery() { 1, ApproximatePointRangeQuery.INT_FORMAT ); + innerApproxQuery.setTopLevel(false); ApproximateScoreQuery innerScoreQuery = new ApproximateScoreQuery(IntPoint.newRangeQuery("inner_field", 50, 150), innerApproxQuery); // Inner boolean query (single clause) @@ -482,6 +502,9 @@ public void testNestedSingleClauseWithApproximateScoreQuery() { when(mockContext.aggregations()).thenReturn(null); when(mockContext.highlight()).thenReturn(null); + outerScoreQuery.setContext(mockContext); + innerScoreQuery.setContext(mockContext); + // Should delegate to nested ApproximateBooleanQuery boolean result = outerQuery.canApproximate(mockContext); assertTrue("Nested single clause should follow inner query logic and be approximatable", result); @@ -715,6 +738,7 @@ public void testBulkScorerWindowedExpansion() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery1.setTopLevel(false); ApproximatePointRangeQuery approxQuery2 = 
new ApproximatePointRangeQuery( "field2", IntPoint.pack(new int[] { 100 }).bytes, @@ -722,6 +746,7 @@ public void testBulkScorerWindowedExpansion() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery2.setTopLevel(false); BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) @@ -802,6 +827,7 @@ public void testApproximateResultsValidation() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery1.setTopLevel(false); ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( "field2", IntPoint.pack(new int[] { lower2 }).bytes, @@ -809,6 +835,7 @@ public void testApproximateResultsValidation() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery2.setTopLevel(false); BooleanQuery approximateBoolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) @@ -897,6 +924,7 @@ public void testWindowSizeHeuristic() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery1.setTopLevel(false); ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( "field2", IntPoint.pack(new int[] { 200 }).bytes, @@ -904,6 +932,7 @@ public void testWindowSizeHeuristic() throws IOException { 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery2.setTopLevel(false); BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) @@ -1010,6 +1039,7 @@ public void testApproximateQueryValidation( 1, ApproximatePointRangeQuery.INT_FORMAT ); + approxQuery1.setTopLevel(false); ApproximatePointRangeQuery approxQuery2 = new ApproximatePointRangeQuery( field2, IntPoint.pack(new int[] { lower2 }).bytes, @@ -1017,6 +1047,7 @@ public void testApproximateQueryValidation( 1, ApproximatePointRangeQuery.INT_FORMAT ); + 
approxQuery2.setTopLevel(false); BooleanQuery boolQuery = new BooleanQuery.Builder().add(approxQuery1, BooleanClause.Occur.FILTER) .add(approxQuery2, BooleanClause.Occur.FILTER) From 3898863e34c844aadb71345139a0b5c92241d5d1 Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Fri, 22 Aug 2025 15:18:11 +0000 Subject: [PATCH 37/38] gradle check Signed-off-by: Sawan Srivastava From 71855282ac721cff74b839b785bcf9e362cb068b Mon Sep 17 00:00:00 2001 From: Sawan Srivastava Date: Thu, 6 Nov 2025 21:04:35 -0800 Subject: [PATCH 38/38] rename vars + update changelog Signed-off-by: Sawan Srivastava --- CHANGELOG.md | 29 +------------------ .../approximate/ApproximateScoreQuery.java | 12 ++++---- 2 files changed, 7 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d620d2c79ca10..875a67d92444c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,35 +23,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Implement GRPC Search params `Highlight`and `Sort` ([#19868](https://github.com/opensearch-project/OpenSearch/pull/19868)) - Implement GRPC ConstantScoreQuery, FuzzyQuery, MatchBoolPrefixQuery, MatchPhrasePrefix, PrefixQuery, MatchQuery ([#19854](https://github.com/opensearch-project/OpenSearch/pull/19854)) - Add async periodic flush task support for pull-based ingestion ([#19878](https://github.com/opensearch-project/OpenSearch/pull/19878)) +- Multifold Improvement in Multi-Clause Boolean Query, Window Scoring Approach ([#19046](https://github.com/opensearch-project/OpenSearch/pull/19046)) -- Add support for Warm Indices Write Block on Flood Watermark breach ([#18375](https://github.com/opensearch-project/OpenSearch/pull/18375)) -- Add support for custom index name resolver from cluster plugin ([#18593](https://github.com/opensearch-project/OpenSearch/pull/18593)) -- Rename WorkloadGroupTestUtil to WorkloadManagementTestUtil 
([#18709](https://github.com/opensearch-project/OpenSearch/pull/18709)) -- Disallow resize for Warm Index, add Parameterized ITs for close in remote store ([#18686](https://github.com/opensearch-project/OpenSearch/pull/18686)) -- Ability to run Code Coverage with Gradle and produce the jacoco reports locally ([#18509](https://github.com/opensearch-project/OpenSearch/issues/18509)) -- [Workload Management] Update logging and Javadoc, rename QueryGroup to WorkloadGroup ([#18711](https://github.com/opensearch-project/OpenSearch/issues/18711)) -- Add NodeResourceUsageStats to ClusterInfo ([#18480](https://github.com/opensearch-project/OpenSearch/issues/18472)) -- Introduce SecureHttpTransportParameters experimental API (to complement SecureTransportParameters counterpart) ([#18572](https://github.com/opensearch-project/OpenSearch/issues/18572)) -- Create equivalents of JSM's AccessController in the java agent ([#18346](https://github.com/opensearch-project/OpenSearch/issues/18346)) -- [WLM] Add WLM mode validation for workload group CRUD requests ([#18652](https://github.com/opensearch-project/OpenSearch/issues/18652)) -- Introduced a new cluster-level API to fetch remote store metadata (segments and translogs) for each shard of an index. ([#18257](https://github.com/opensearch-project/OpenSearch/pull/18257)) -- Add last index request timestamp columns to the `_cat/indices` API. 
([10766](https://github.com/opensearch-project/OpenSearch/issues/10766)) -- Introduce a new pull-based ingestion plugin for file-based indexing (for local testing) ([#18591](https://github.com/opensearch-project/OpenSearch/pull/18591)) -- Add support for search pipeline in search and msearch template ([#18564](https://github.com/opensearch-project/OpenSearch/pull/18564)) -- [Workload Management] Modify logging message in WorkloadGroupService ([#18712](https://github.com/opensearch-project/OpenSearch/pull/18712)) -- Add BooleanQuery rewrite moving constant-scoring must clauses to filter clauses ([#18510](https://github.com/opensearch-project/OpenSearch/issues/18510)) -- Add functionality for plugins to inject QueryCollectorContext during QueryPhase ([#18637](https://github.com/opensearch-project/OpenSearch/pull/18637)) -- Add support for non-timing info in profiler ([#18460](https://github.com/opensearch-project/OpenSearch/issues/18460)) -- Extend Approximation Framework to other numeric types ([#18530](https://github.com/opensearch-project/OpenSearch/issues/18530)) -- Add Semantic Version field type mapper and extensive unit tests([#18454](https://github.com/opensearch-project/OpenSearch/pull/18454)) -- Pass index settings to system ingest processor factories. 
([#18708](https://github.com/opensearch-project/OpenSearch/pull/18708)) -- Include named queries from rescore contexts in matched_queries array ([#18697](https://github.com/opensearch-project/OpenSearch/pull/18697)) -- Add the configurable limit on rule cardinality ([#18663](https://github.com/opensearch-project/OpenSearch/pull/18663)) -- [Experimental] Start in "clusterless" mode if a clusterless ClusterPlugin is loaded ([#18479](https://github.com/opensearch-project/OpenSearch/pull/18479)) -- [Star-Tree] Add star-tree search related stats ([#18707](https://github.com/opensearch-project/OpenSearch/pull/18707)) -- Add support for plugins to profile information ([#18656](https://github.com/opensearch-project/OpenSearch/pull/18656)) -- Add support for Combined Fields query ([#18724](https://github.com/opensearch-project/OpenSearch/pull/18724)) -- Multifold Improvement in Multi-Clause Boolean Query, Window Scoring Approach ([#19046](https://github.com/opensearch-project/OpenSearch/pull/19046)) ### Changed - Faster `terms` query creation for `keyword` field with index and docValues enabled ([#19350](https://github.com/opensearch-project/OpenSearch/pull/19350)) diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java index 5a84332e6e08a..b561655c39695 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java @@ -64,15 +64,15 @@ public void setContext(SearchContext context) { resolvedQuery = ApproximateBooleanQuery.boolRewrite(resolvedQuery, context.searcher()); } - if (resolvedQuery instanceof ApproximateBooleanQuery appxBool) { - for (BooleanClause boolClause : appxBool.boolQuery.clauses()) { - if (boolClause.query() instanceof ApproximateScoreQuery 
apprxQuery) { - if (apprxQuery.resolvedQuery instanceof ApproximateBooleanQuery boolQuery) { + if (resolvedQuery instanceof ApproximateBooleanQuery approximateBool) { + for (BooleanClause boolClause : approximateBool.boolQuery.clauses()) { + if (boolClause.query() instanceof ApproximateScoreQuery approximateQuery) { + if (approximateQuery.resolvedQuery instanceof ApproximateBooleanQuery boolQuery) { boolQuery.setTopLevel(false); - } else if (apprxQuery.resolvedQuery instanceof ApproximatePointRangeQuery pointQuery) { + } else if (approximateQuery.resolvedQuery instanceof ApproximatePointRangeQuery pointQuery) { pointQuery.setTopLevel(false); } - apprxQuery.setContext(context); + approximateQuery.setContext(context); } } }