diff --git a/CHANGELOG.md b/CHANGELOG.md index 560bf6b71c7fc..68a587b245849 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Star-Tree] Add search support for ip field type ([#18671](https://github.com/opensearch-project/OpenSearch/pull/18671)) - [Derived Source] Add integration of derived source feature across various paths like get/search/recovery ([#18565](https://github.com/opensearch-project/OpenSearch/pull/18565)) - Supporting Scripted Metric Aggregation when reducing aggregations in InternalValueCount and InternalAvg ([18411](https://github.com/opensearch-project/OpenSearch/pull18411))) +- Support `search_after` numeric queries with Approximation Framework ([#18896](https://github.com/opensearch-project/OpenSearch/pull/18896)) ### Changed - Update Subject interface to use CheckedRunnable ([#18570](https://github.com/opensearch-project/OpenSearch/issues/18570)) diff --git a/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java b/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java index cf091f8d03590..3f50398a1121d 100644 --- a/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java +++ b/modules/mapper-extras/src/main/java/org/opensearch/index/mapper/ScaledFloatFieldMapper.java @@ -217,6 +217,22 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + double doubleValue = parse(value); + long scaledValue = Math.round(scale(doubleValue)); + if (roundUp) { + if (scaledValue < Long.MAX_VALUE) { + scaledValue = scaledValue + 1; + } + } else { + if (scaledValue > Long.MIN_VALUE) { + scaledValue = scaledValue - 1; + } + } + return encodePoint(scaledValue); + } + public double getScalingFactor() { return scalingFactor; } diff --git a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java index fe33b7fc3da10..57877f7ce92b1 100644 --- a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java @@ -619,6 +619,23 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + // Always parse with roundUp=false to get consistent date math + // In this method the parseToLong is only used for date math rounding operations + long timestamp = parseToLong(value, false, null, null, null); + if (roundUp) { + if (timestamp < Long.MAX_VALUE) { + timestamp = timestamp + 1; + } + } else { + if (timestamp > Long.MIN_VALUE) { + timestamp = timestamp - 1; + } + } + return encodePoint(timestamp); + } + @Override public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) { failIfNotIndexedAndNoDocValues(); diff --git a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java index 39f1407ecbeaa..6fa91e33df7ae 100644 --- a/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/NumberFieldMapper.java @@ -283,6 +283,17 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Float numericValue = parse(value, true); + if (roundUp) { + numericValue = HalfFloatPoint.nextUp(numericValue); + } else { + numericValue = HalfFloatPoint.nextDown(numericValue); + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return HalfFloatPoint.sortableShortToHalfFloat((short) value); @@ -459,6 +470,17 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Float numericValue = parse(value, true); + if (roundUp) { + numericValue = FloatPoint.nextUp(numericValue); + } else { + numericValue = FloatPoint.nextDown(numericValue); + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return NumericUtils.sortableIntToFloat((int) value); @@ -626,6 +648,17 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Double numericValue = parse(value, true); + if (roundUp) { + numericValue = DoublePoint.nextUp(numericValue); + } else { + numericValue = DoublePoint.nextDown(numericValue); + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return NumericUtils.sortableLongToDouble(value); @@ -789,6 +822,23 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Byte numericValue = parse(value, true); + if (roundUp) { + // ASC: exclusive lower bound + if (numericValue < Byte.MAX_VALUE) { + numericValue = (byte) (numericValue + 1); + } + } else { + // DESC: exclusive upper bound + if (numericValue > Byte.MIN_VALUE) { + numericValue = (byte) (numericValue - 1); + } + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return objectToDouble(value); @@ -873,6 +923,22 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Short numericValue = parse(value, true); + if (roundUp) { + // ASC: exclusive lower bound + if (numericValue < Short.MAX_VALUE) { + numericValue = (short) (numericValue + 1); + } + } else { + if (numericValue > Short.MIN_VALUE) { + numericValue = (short) (numericValue - 1); + } + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return (double) value; @@ -953,6 +1019,23 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Integer numericValue = parse(value, true); + // Always apply exclusivity + if (roundUp) { + if (numericValue < Integer.MAX_VALUE) { + numericValue = numericValue + 1; + } + } else { + if (numericValue > Integer.MIN_VALUE) { + numericValue = numericValue - 1; + } + } + + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return (double) value; @@ -1139,6 +1222,23 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + Long numericValue = parse(value, true); + if (roundUp) { + // ASC: exclusive lower bound + if (numericValue < Long.MAX_VALUE) { + numericValue = numericValue + 1; + } + } else { + // DESC: exclusive upper bound + if (numericValue > Long.MIN_VALUE) { + numericValue = numericValue - 1; + } + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return (double) value; @@ -1281,6 +1381,22 @@ public byte[] encodePoint(Number value) { return point; } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + BigInteger numericValue = parse(value, true); + if (roundUp) { + if (numericValue.compareTo(Numbers.MAX_UNSIGNED_LONG_VALUE) < 0) { + numericValue = numericValue.add(BigInteger.ONE); + } + } else { + // DESC: exclusive upper bound + if (numericValue.compareTo(Numbers.MIN_UNSIGNED_LONG_VALUE) > 0) { + numericValue = numericValue.subtract(BigInteger.ONE); + } + } + return encodePoint(numericValue); + } + @Override public double toDoubleValue(long value) { return Numbers.unsignedLongToDouble(value); @@ -1851,6 +1967,11 @@ public byte[] encodePoint(Number value) { return type.encodePoint(value); } + @Override + public byte[] encodePoint(Object value, boolean roundUp) { + return type.encodePoint(value, roundUp); + } + @Override public double toDoubleValue(long value) { return type.toDoubleValue(value); diff --git a/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java b/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java index be746a5526594..6dc66bede1e05 100644 --- a/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java +++ b/server/src/main/java/org/opensearch/index/mapper/NumericPointEncoder.java @@ -13,4 +13,12 @@ */ public interface NumericPointEncoder { byte[] encodePoint(Number value); + + /** + * Encodes an Object value to byte array for Approximation Framework search_after optimization. + * @param value the search_after value as Object + * @param roundUp whether to round up (for lower bounds) or down (for upper bounds) + * @return encoded byte array + */ + byte[] encodePoint(Object value, boolean roundUp); } diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java index adecb8c89ef82..046eb4dc1c86f 100644 --- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java +++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java @@ -31,6 +31,8 @@ import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.DocIdSetBuilder; import org.apache.lucene.util.IntsRef; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.mapper.NumericPointEncoder; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.SortOrder; @@ -52,10 +54,9 @@ public class ApproximatePointRangeQuery extends ApproximateQuery { public static final Function UNSIGNED_LONG_FORMAT = bytes -> BigIntegerPoint.decodeDimension(bytes, 0).toString(); private int size; - private SortOrder sortOrder; - - public final PointRangeQuery pointRangeQuery; + public PointRangeQuery pointRangeQuery; + private final Function valueToString; public ApproximatePointRangeQuery( String field, @@ -78,6 +79,7 @@ protected ApproximatePointRangeQuery( ) { this.size = size; this.sortOrder = sortOrder; + this.valueToString = valueToString; this.pointRangeQuery = new PointRangeQuery(field, lowerPoint, upperPoint, numDims) { @Override protected String toString(int dimension, byte[] value) { @@ -114,12 +116,12 @@ public void visit(QueryVisitor visitor) { @Override public final ConstantScoreWeight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim()); + Weight pointRangeQueryWeight = pointRangeQuery.createWeight(searcher, scoreMode, boost); return new ConstantScoreWeight(this, boost) { - private final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(pointRangeQuery.getBytesPerDim()); - // we pull this from PointRangeQuery since it is final private boolean matches(byte[] packedValue) { for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { @@ -138,7 +140,6 @@ private boolean matches(byte[] packedValue) { // we pull this from PointRangeQuery since it is final private PointValues.Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { - boolean crosses = false; for (int dim = 0; dim < pointRangeQuery.getNumDims(); dim++) { @@ -352,6 +353,7 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti if (checkValidPointValues(values) == false) { return null; } + // values.size(): total points indexed, In most cases: values.size() ≈ number of documents (assuming single-valued fields) if (size > values.size()) { return pointRangeQueryWeight.scorerSupplier(context); } else { @@ -423,6 +425,19 @@ public boolean isCacheable(LeafReaderContext ctx) { }; } + private byte[] computeEffectiveBound(SearchContext context, boolean isLowerBound) { + byte[] originalBound = isLowerBound ? pointRangeQuery.getLowerPoint() : pointRangeQuery.getUpperPoint(); + boolean isAscending = sortOrder == null || sortOrder.equals(SortOrder.ASC); + if ((isLowerBound && isAscending) || (isLowerBound == false && isAscending == false)) { + Object searchAfterValue = context.request().source().searchAfter()[0]; + MappedFieldType fieldType = context.getQueryShardContext().fieldMapper(pointRangeQuery.getField()); + if (fieldType instanceof NumericPointEncoder encoder) { + return encoder.encodePoint(searchAfterValue, isLowerBound); + } + } + return originalBound; + } + @Override public boolean canApproximate(SearchContext context) { if (context == null) { @@ -435,7 +450,6 @@ public boolean canApproximate(SearchContext context) { if (context.trackTotalHitsUpTo() == SearchContext.TRACK_TOTAL_HITS_ACCURATE) { return false; } - // size 0 could be set for caching if (context.from() + context.size() == 0) { this.setSize(SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO); @@ -459,12 +473,24 @@ public boolean canApproximate(SearchContext context) { // Cannot sort documents missing this field. return false; } + this.setSortOrder(primarySortField.order()); if (context.request().source().searchAfter() != null) { - // TODO: We *could* optimize searchAfter, especially when this is the only sort field, but existing pruning is pretty - // good. - return false; + byte[] lower; + byte[] upper; + if (sortOrder == SortOrder.ASC) { + lower = computeEffectiveBound(context, true); + upper = pointRangeQuery.getUpperPoint(); + } else { + lower = pointRangeQuery.getLowerPoint(); + upper = computeEffectiveBound(context, false); + } + this.pointRangeQuery = new PointRangeQuery(pointRangeQuery.getField(), lower, upper, pointRangeQuery.getNumDims()) { + @Override + protected String toString(int dimension, byte[] value) { + return valueToString.apply(value); + } + }; } - this.setSortOrder(primarySortField.order()); } return context.request().source().terminateAfter() == SearchContext.DEFAULT_TERMINATE_AFTER; } diff --git a/server/src/test/java/org/opensearch/index/mapper/DateFieldMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/DateFieldMapperTests.java index 8e4bb86bb58e7..f099a8112b179 100644 --- a/server/src/test/java/org/opensearch/index/mapper/DateFieldMapperTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/DateFieldMapperTests.java @@ -32,6 +32,7 @@ package org.opensearch.index.mapper; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.index.DocValuesSkipIndexType; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; @@ -618,4 +619,28 @@ public void testDeriveSource_NoValue() throws IOException { String source = builder.toString(); assertEquals("{}", source); } + + public void testDateEncodePoint() { + DateFieldMapper.DateFieldType fieldType = new DateFieldMapper.DateFieldType("test_field"); + // Test basic roundUp + long baseTime = fieldType.parse("2024-01-15T10:30:00Z"); + byte[] encoded = fieldType.encodePoint("2024-01-15T10:30:00Z", true); + long decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(baseTime + 1, decoded); + // Test basic roundDown + encoded = fieldType.encodePoint("2024-01-15T10:30:00Z", false); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(baseTime - 1, decoded); + // Test with extreme long values, + long largeEpoch = 253402300799999L; + String largeEpochStr = String.valueOf(largeEpoch); + encoded = fieldType.encodePoint(largeEpochStr, true); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals("Should increment epoch millis", largeEpoch + 1, decoded); + long negativeEpoch = -377705116800000L; + String negativeEpochStr = String.valueOf(negativeEpoch); + encoded = fieldType.encodePoint(negativeEpochStr, false); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals("Should decrement epoch millis", negativeEpoch - 1, decoded); + } } diff --git a/server/src/test/java/org/opensearch/index/mapper/NumberFieldMapperTests.java b/server/src/test/java/org/opensearch/index/mapper/NumberFieldMapperTests.java index 92cc42cec96fb..86153c4e4bb52 100644 --- a/server/src/test/java/org/opensearch/index/mapper/NumberFieldMapperTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/NumberFieldMapperTests.java @@ -33,6 +33,10 @@ package org.opensearch.index.mapper; import org.apache.lucene.document.Document; +import org.apache.lucene.document.DoublePoint; +import org.apache.lucene.document.FloatPoint; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.LongPoint; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.DirectoryReader; @@ -40,9 +44,11 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.sandbox.document.BigIntegerPoint; import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.store.Directory; import org.apache.lucene.util.NumericUtils; +import org.opensearch.common.Numbers; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.MediaTypeRegistry; @@ -547,4 +553,170 @@ private Document createDocument(NumberFieldMapper.NumberType type, List } return doc; } + + public void testHalfFloatEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.HALF_FLOAT; + // Test roundUp = true + byte[] encoded = type.encodePoint(100.5f, true); + float decoded = HalfFloatPoint.decodeDimension(encoded, 0); + assertTrue("Should round up", decoded > 100.5f); + // Test roundUp = false + encoded = type.encodePoint(100.5f, false); + decoded = HalfFloatPoint.decodeDimension(encoded, 0); + assertTrue("Should round down", decoded < 100.5f); + encoded = type.encodePoint(0.0f, true); + decoded = HalfFloatPoint.decodeDimension(encoded, 0); + assertTrue("Zero roundUp should be positive", decoded > 0.0f); + encoded = type.encodePoint("123.45", true); + decoded = HalfFloatPoint.decodeDimension(encoded, 0); + assertTrue("String parsing should work", decoded > 123.45f); + } + + public void testFloatEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.FLOAT; + // Test roundUp = true + byte[] encoded = type.encodePoint(100.5f, true); + float decoded = FloatPoint.decodeDimension(encoded, 0); + assertEquals(FloatPoint.nextUp(100.5f), decoded, 0.0f); + // Test roundUp = false + encoded = type.encodePoint(100.5f, false); + decoded = FloatPoint.decodeDimension(encoded, 0); + assertEquals(FloatPoint.nextDown(100.5f), decoded, 0.0f); + } + + public void testDoubleEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.DOUBLE; + // Test roundUp = true + byte[] encoded = type.encodePoint(100.5, true); + double decoded = DoublePoint.decodeDimension(encoded, 0); + assertEquals(DoublePoint.nextUp(100.5), decoded, 0.0); + // Test roundUp = false + encoded = type.encodePoint(100.5, false); + decoded = DoublePoint.decodeDimension(encoded, 0); + assertEquals(DoublePoint.nextDown(100.5), decoded, 0.0); + encoded = type.encodePoint("123.456789", true); + decoded = DoublePoint.decodeDimension(encoded, 0); + assertTrue("String parsing should work", decoded > 123.456789); + } + + public void testIntegerEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.INTEGER; + // Test roundUp = true + byte[] encoded = type.encodePoint(100, true); + int decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(101, decoded); + // Test roundUp = false + encoded = type.encodePoint(100, false); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(99, decoded); + encoded = type.encodePoint(Integer.MAX_VALUE, true); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(Integer.MAX_VALUE, decoded); // Can't increment + encoded = type.encodePoint(Integer.MIN_VALUE, false); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(Integer.MIN_VALUE, decoded); // Can't decrement + encoded = type.encodePoint(100.7f, true); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(101, decoded); // 100.7 coerced to 100, then incremented + } + + public void testLongEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.LONG; + // Test roundUp = true + byte[] encoded = type.encodePoint(100L, true); + long decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(101L, decoded); + // Test roundUp = false + encoded = type.encodePoint(100L, false); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(99L, decoded); + // Test edge cases + encoded = type.encodePoint(Long.MAX_VALUE, true); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(Long.MAX_VALUE, decoded); // Can't increment + encoded = type.encodePoint("9223372036854775806", true); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(9223372036854775807L, decoded); + } + + public void testByteEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.BYTE; + // Test roundUp = true + byte[] encoded = type.encodePoint((byte) 100, true); + int decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(101, decoded); + // Test roundUp = false + encoded = type.encodePoint((byte) 100, false); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(99, decoded); + // Test edge cases + encoded = type.encodePoint(Byte.MAX_VALUE, true); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(Byte.MAX_VALUE, decoded); + encoded = type.encodePoint(Byte.MIN_VALUE, false); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(Byte.MIN_VALUE, decoded); + } + + public void testShortEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.SHORT; + // Test roundUp = true + byte[] encoded = type.encodePoint((short) 100, true); + int decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(101, decoded); + // Test roundUp = false + encoded = type.encodePoint((short) 100, false); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(99, decoded); + // Test edge cases + encoded = type.encodePoint(Short.MAX_VALUE, true); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(Short.MAX_VALUE, decoded); + } + + public void testUnsignedLongEncodePoint() { + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.UNSIGNED_LONG; + // Test roundUp = true + byte[] encoded = type.encodePoint(BigInteger.valueOf(100L), true); + BigInteger decoded = BigIntegerPoint.decodeDimension(encoded, 0); + assertEquals(BigInteger.valueOf(101L), decoded); + // Test roundUp = false + encoded = type.encodePoint(BigInteger.valueOf(100L), false); + decoded = BigIntegerPoint.decodeDimension(encoded, 0); + assertEquals(BigInteger.valueOf(99L), decoded); + // Test edge cases + BigInteger maxUnsignedLong = Numbers.MAX_UNSIGNED_LONG_VALUE; + encoded = type.encodePoint(maxUnsignedLong, true); + decoded = BigIntegerPoint.decodeDimension(encoded, 0); + assertEquals(maxUnsignedLong, decoded); // Can't increment + encoded = type.encodePoint("18446744073709551614", true); + decoded = BigIntegerPoint.decodeDimension(encoded, 0); + assertEquals(new BigInteger("18446744073709551615"), decoded); + } + + public void testCoercionBehavior() { + // Test that decimal values are properly coerced for integer types + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.LONG; + // 100.7 should be coerced to 100, then incremented to 101 + byte[] encoded = type.encodePoint(100.7, true); + long decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(101L, decoded); + // 100.3 should be coerced to 100, then decremented to 99 + encoded = type.encodePoint(100.3, false); + decoded = LongPoint.decodeDimension(encoded, 0); + assertEquals(99L, decoded); + } + + public void testNegativeNumberHandling() { + // Test negative numbers for integer types + NumberFieldMapper.NumberType type = NumberFieldMapper.NumberType.INTEGER; + // Negative number roundUp (exclusive lower bound) + byte[] encoded = type.encodePoint(-100, true); + int decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(-99, decoded); + // Negative number roundDown (exclusive upper bound) + encoded = type.encodePoint(-100, false); + decoded = IntPoint.decodeDimension(encoded, 0); + assertEquals(-101, decoded); + } } diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java index e7ef69b2ad8c6..1f70699d0cde2 100644 --- a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java +++ b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java @@ -21,8 +21,10 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.sandbox.document.BigIntegerPoint; import org.apache.lucene.sandbox.document.HalfFloatPoint; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TopDocs; @@ -33,6 +35,8 @@ import org.opensearch.common.time.DateFormatter; import org.opensearch.common.time.DateMathParser; import org.opensearch.index.mapper.DateFieldMapper.DateFieldType; +import org.opensearch.index.mapper.NumberFieldMapper; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.ShardSearchRequest; @@ -97,6 +101,11 @@ Query rangeQuery(String fieldName, Number lower, Number upper) { SortField.Type getSortFieldType() { return SortField.Type.INT; } + + @Override + NumberFieldMapper.NumberType getNumberType() { + return NumberFieldMapper.NumberType.INTEGER; + } }, LONG("long_field", ApproximatePointRangeQuery.LONG_FORMAT) { @Override @@ -125,6 +134,11 @@ Query rangeQuery(String fieldName, Number lower, Number upper) { SortField.Type getSortFieldType() { return SortField.Type.LONG; } + + @Override + NumberFieldMapper.NumberType getNumberType() { + return NumberFieldMapper.NumberType.LONG; + } }, HALF_FLOAT("half_float_field", ApproximatePointRangeQuery.HALF_FLOAT_FORMAT) { @Override @@ -158,6 +172,11 @@ Query rangeQuery(String fieldName, Number lower, Number upper) { SortField.Type getSortFieldType() { return SortField.Type.LONG; } + + @Override + NumberFieldMapper.NumberType getNumberType() { + return NumberFieldMapper.NumberType.HALF_FLOAT; + } }, FLOAT("float_field", ApproximatePointRangeQuery.FLOAT_FORMAT) { @Override @@ -184,6 +203,11 @@ Query rangeQuery(String fieldName, Number lower, Number upper) { SortField.Type getSortFieldType() { return SortField.Type.FLOAT; } + + @Override + NumberFieldMapper.NumberType getNumberType() { + return NumberFieldMapper.NumberType.FLOAT; + } }, DOUBLE("double_field", ApproximatePointRangeQuery.DOUBLE_FORMAT) { @Override @@ -210,6 +234,11 @@ Query rangeQuery(String fieldName, Number lower, Number upper) { SortField.Type getSortFieldType() { return SortField.Type.DOUBLE; } + + @Override + NumberFieldMapper.NumberType getNumberType() { + return NumberFieldMapper.NumberType.DOUBLE; + } }, UNSIGNED_LONG("unsigned_long_field", ApproximatePointRangeQuery.UNSIGNED_LONG_FORMAT) { @Override @@ -247,6 +276,11 @@ SortField.Type getSortFieldType() { String getSortFieldName() { return fieldName + "_sort"; } + + @Override + NumberFieldMapper.NumberType getNumberType() { + return NumberFieldMapper.NumberType.UNSIGNED_LONG; + } }; final String fieldName; @@ -267,6 +301,8 @@ String getSortFieldName() { abstract SortField.Type getSortFieldType(); + abstract NumberFieldMapper.NumberType getNumberType(); + String getSortFieldName() { return fieldName; } @@ -1008,7 +1044,7 @@ public void testDateRangeIncludingNowQueryApproximation() throws IOException { } } - private void testApproximateVsExactQueryWithDateField( + public void testApproximateVsExactQueryWithDateField( IndexSearcher searcher, DateFieldType dateFieldType, String lowerBound, @@ -1188,4 +1224,120 @@ public void testApproximateWithSort() { assertTrue("Should approximate with single sort on same field", query.canApproximate(mockContext)); } } + + public void testApproximateRangeWithSearchAfterAsc() throws IOException { + testApproximateRangeWithSearchAfter(SortOrder.ASC); + } + + public void testApproximateRangeWithSearchAfterDesc() throws IOException { + testApproximateRangeWithSearchAfter(SortOrder.DESC); + } + + private void testApproximateRangeWithSearchAfter(SortOrder sortOrder) throws IOException { + if (numericType == NumericType.HALF_FLOAT) { + // Skip - HALF_FLOAT uses different fields for storage vs sorting which causes issues with search_after boundary checking during + // tests + return; + } + try (Directory directory = newDirectory()) { + try (RandomIndexWriter iw = new RandomIndexWriter(random(), directory, new WhitespaceAnalyzer())) { + int dims = 1; + int numPoints = RandomNumbers.randomIntBetween(random(), 2000, 5000); + for (int i = 0; i < numPoints; i++) { + Document doc = new Document(); + numericType.addField(doc, numericType.fieldName, i); + numericType.addDocValuesField(doc, numericType.fieldName, i); + iw.addDocument(doc); + if (random().nextInt(20) == 0) { + iw.flush(); + } + } + iw.flush(); + if (random().nextBoolean()) { + iw.forceMerge(1); + } + try (IndexReader reader = iw.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + long lower = RandomNumbers.randomLongBetween(random(), 0, numPoints / 4); + long upper = RandomNumbers.randomLongBetween(random(), 3 * numPoints / 4, numPoints - 1); + int size = RandomNumbers.randomIntBetween(random(), 10, 50); + long searchAfterValue = RandomNumbers.randomLongBetween(random(), lower, upper - size); + // First, get a document at searchAfterValue to use as the searchAfter point + Query exactValueQuery = numericType.rangeQuery(numericType.fieldName, searchAfterValue, searchAfterValue); + boolean reverseSort = sortOrder == SortOrder.DESC; + Sort sort = new Sort(new SortField(numericType.getSortFieldName(), numericType.getSortFieldType(), reverseSort)); + TopDocs searchAfterDocs = searcher.search(exactValueQuery, 1, sort); + FieldDoc searchAfterDoc = (FieldDoc) searchAfterDocs.scoreDocs[0]; + // Create mock context for approximate query + SearchContext mockContext = mock(SearchContext.class); + ShardSearchRequest mockRequest = mock(ShardSearchRequest.class); + SearchSourceBuilder source = new SearchSourceBuilder(); + source.sort(new FieldSortBuilder(numericType.fieldName).order(sortOrder)); + source.searchAfter(searchAfterDoc.fields); + when(mockContext.aggregations()).thenReturn(null); + when(mockContext.trackTotalHitsUpTo()).thenReturn(10000); + when(mockContext.from()).thenReturn(0); + when(mockContext.size()).thenReturn(size); + when(mockContext.request()).thenReturn(mockRequest); + when(mockRequest.source()).thenReturn(source); + NumberFieldMapper.NumberFieldType fieldType = new NumberFieldMapper.NumberFieldType( + numericType.fieldName, + numericType.getNumberType() + ); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + when(queryShardContext.fieldMapper(numericType.fieldName)).thenReturn(fieldType); + when(mockContext.getQueryShardContext()).thenReturn(queryShardContext); + // Test approximate query with searchAfter + ApproximatePointRangeQuery approxQuery = new ApproximatePointRangeQuery( + numericType.fieldName, + numericType.encode(lower), + numericType.encode(upper), + dims, + size, + sortOrder, + numericType.format + ); + assertTrue("Should be able to approximate", approxQuery.canApproximate(mockContext)); + TopDocs approxDocs = searcher.search(approxQuery, size, sort); + // Compare with exact query using Lucene's searchAfter + Query exactQuery = numericType.rangeQuery(numericType.fieldName, lower, upper); + TopDocs exactDocs = searcher.searchAfter(searchAfterDoc, exactQuery, size, sort); + // Verify results match + assertEquals( + "Approximate and exact queries should return same number of docs", + exactDocs.scoreDocs.length, + approxDocs.scoreDocs.length + ); + for (int i = 0; i < Math.min(approxDocs.scoreDocs.length, exactDocs.scoreDocs.length); i++) { + FieldDoc approxFieldDoc = (FieldDoc) approxDocs.scoreDocs[i]; + FieldDoc exactFieldDoc = (FieldDoc) exactDocs.scoreDocs[i]; + assertEquals("Doc at position " + i + " should match", exactFieldDoc.doc, approxFieldDoc.doc); + assertEquals( + "Sort value at position " + i + " should match", + (exactFieldDoc.fields[0]), + (approxFieldDoc.fields[0]) + ); + } + // Verify all returned docs are correctly ordered relative to searchAfterValue + for (ScoreDoc scoreDoc : approxDocs.scoreDocs) { + FieldDoc fieldDoc = (FieldDoc) scoreDoc; + Number value = (Number) fieldDoc.fields[0]; + long searchAfterLong = ((Number) searchAfterDoc.fields[0]).longValue(); + + if (sortOrder == SortOrder.ASC) { + assertTrue( + "Doc value " + value + " should be > searchAfterValue " + searchAfterLong, + value.longValue() > searchAfterLong + ); + } else { + assertTrue( + "Doc value " + value + " should be < searchAfterValue " + searchAfterLong, + value.longValue() < searchAfterLong + ); + } + } + } + } + } + } }