diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java index a7516a6fd6b24..47e1abfe71f8e 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregationBuilder.java @@ -68,6 +68,7 @@ public final class CardinalityAggregationBuilder extends ValuesSourceAggregation private static final ParseField REHASH = new ParseField("rehash").withAllDeprecated("no replacement - values will always be rehashed"); public static final ParseField PRECISION_THRESHOLD_FIELD = new ParseField("precision_threshold"); + public static final ParseField EXECUTION_HINT_FIELD = new ParseField(("execution_hint")); public static final ObjectParser PARSER = ObjectParser.fromBuilder( NAME, @@ -76,6 +77,7 @@ public final class CardinalityAggregationBuilder extends ValuesSourceAggregation static { ValuesSourceAggregationBuilder.declareFields(PARSER, true, false, false); PARSER.declareLong(CardinalityAggregationBuilder::precisionThreshold, CardinalityAggregationBuilder.PRECISION_THRESHOLD_FIELD); + PARSER.declareString(CardinalityAggregationBuilder::executionHint, CardinalityAggregationBuilder.EXECUTION_HINT_FIELD); PARSER.declareLong((b, v) -> {/*ignore*/}, REHASH); } @@ -85,6 +87,8 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) { private Long precisionThreshold = null; + private String executionHint = null; + public CardinalityAggregationBuilder(String name) { super(name); } @@ -96,6 +100,7 @@ public CardinalityAggregationBuilder( ) { super(clone, factoriesBuilder, metadata); this.precisionThreshold = clone.precisionThreshold; + this.executionHint = clone.executionHint; } @Override @@ -111,6 +116,7 @@ public CardinalityAggregationBuilder(StreamInput in) throws IOException { if (in.readBoolean()) { precisionThreshold = in.readLong(); } + executionHint = in.readOptionalString(); } @Override @@ -125,6 +131,7 @@ protected void innerWriteTo(StreamOutput out) throws IOException { if (hasPrecisionThreshold) { out.writeLong(precisionThreshold); } + out.writeOptionalString(executionHint); } @Override @@ -155,6 +162,15 @@ public Long precisionThreshold() { return precisionThreshold; } + public CardinalityAggregationBuilder executionHint(String executionHint) { + this.executionHint = executionHint; + return this; + } + + public String executionHint() { + return executionHint; + } + @Override protected CardinalityAggregatorFactory innerBuild( QueryShardContext queryShardContext, @@ -162,7 +178,16 @@ protected CardinalityAggregatorFactory innerBuild( AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder ) throws IOException { - return new CardinalityAggregatorFactory(name, config, precisionThreshold, queryShardContext, parent, subFactoriesBuilder, metadata); + return new CardinalityAggregatorFactory( + name, + config, + precisionThreshold, + executionHint, + queryShardContext, + parent, + subFactoriesBuilder, + metadata + ); } @Override @@ -170,12 +195,15 @@ public XContentBuilder doXContentBody(XContentBuilder builder, Params params) th if (precisionThreshold != null) { builder.field(PRECISION_THRESHOLD_FIELD.getPreferredName(), precisionThreshold); } + if (executionHint != null) { + builder.field(EXECUTION_HINT_FIELD.getPreferredName(), executionHint); + } return builder; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), precisionThreshold); + return Objects.hash(super.hashCode(), precisionThreshold, executionHint); } @Override @@ -184,7 +212,7 @@ public boolean equals(Object obj) { if (obj == null || getClass() != obj.getClass()) return false; if (super.equals(obj) == false) return false; CardinalityAggregationBuilder other = (CardinalityAggregationBuilder) obj; - return Objects.equals(precisionThreshold, other.precisionThreshold); + return Objects.equals(precisionThreshold, other.precisionThreshold) && Objects.equals(executionHint, other.executionHint); } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java index d578c37af8818..99b73b68f0163 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java @@ -89,6 +89,7 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue private static final Logger logger = LogManager.getLogger(CardinalityAggregator.class); + private final CardinalityAggregatorFactory.ExecutionMode executionMode; private final int precision; private final ValuesSource valuesSource; @@ -111,6 +112,7 @@ public CardinalityAggregator( String name, ValuesSourceConfig valuesSourceConfig, int precision, + CardinalityAggregatorFactory.ExecutionMode executionMode, SearchContext context, Aggregator parent, Map metadata @@ -121,6 +123,7 @@ public CardinalityAggregator( this.precision = precision; this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1); this.valuesSourceConfig = valuesSourceConfig; + this.executionMode = executionMode; } @Override @@ -151,6 +154,9 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException { if (maxOrd == 0) { emptyCollectorsUsed++; return new EmptyCollector(); + } else if (executionMode == CardinalityAggregatorFactory.ExecutionMode.ORDINALS) { // Force OrdinalsCollector + ordinalsCollectorsUsed++; + collector = new OrdinalsCollector(counts, ordinalValues, context.bigArrays()); } else { final long ordinalsMemoryUsage = OrdinalsCollector.memoryOverhead(maxOrd); final long countsMemoryUsage = HyperLogLogPlusPlus.memoryUsage(precision); @@ -480,7 +486,7 @@ public void close() { * * @opensearch.internal */ - private static class DirectCollector extends Collector { + public static class DirectCollector extends Collector { private final MurmurHash3Values hashes; private final HyperLogLogPlusPlus counts; @@ -517,7 +523,7 @@ public void close() { * * @opensearch.internal */ - private static class OrdinalsCollector extends Collector { + public static class OrdinalsCollector extends Collector { private static final long SHALLOW_FIXEDBITSET_SIZE = RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class); diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java index 980667b45324e..c70ec7c645e63 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorFactory.java @@ -44,6 +44,7 @@ import org.opensearch.search.internal.SearchContext; import java.io.IOException; +import java.util.Locale; import java.util.Map; /** @@ -53,12 +54,45 @@ */ class CardinalityAggregatorFactory extends ValuesSourceAggregatorFactory { + /** + * Execution mode for cardinality agg + * + * @opensearch.internal + */ + public enum ExecutionMode { + + UNSET, + DIRECT, + ORDINALS; + + ExecutionMode() {} + + public static ExecutionMode fromString(String value) { + if (value == null) { + return UNSET; + } + try { + return ExecutionMode.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Unknown `execution_hint`: [" + value + "], expected any of [direct, ordinals]"); + } + } + + @Override + public String toString() { + return this.name().toLowerCase(Locale.ROOT); + } + } + + private final ExecutionMode executionMode; + private final Long precisionThreshold; CardinalityAggregatorFactory( String name, ValuesSourceConfig config, Long precisionThreshold, + String executionHint, QueryShardContext queryShardContext, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, @@ -66,6 +100,7 @@ class CardinalityAggregatorFactory extends ValuesSourceAggregatorFactory { ) throws IOException { super(name, config, queryShardContext, parent, subFactoriesBuilder, metadata); this.precisionThreshold = precisionThreshold; + this.executionMode = ExecutionMode.fromString(executionHint); } public static void registerAggregators(ValuesSourceRegistry.Builder builder) { @@ -74,7 +109,7 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) { @Override protected Aggregator createUnmapped(SearchContext searchContext, Aggregator parent, Map metadata) throws IOException { - return new CardinalityAggregator(name, config, precision(), searchContext, parent, metadata); + return new CardinalityAggregator(name, config, precision(), executionMode, searchContext, parent, metadata); } @Override @@ -86,7 +121,7 @@ protected Aggregator doCreateInternal( ) throws IOException { return queryShardContext.getValuesSourceRegistry() .getAggregator(CardinalityAggregationBuilder.REGISTRY_KEY, config) - .build(name, config, precision(), searchContext, parent, metadata); + .build(name, config, precision(), executionMode, searchContext, parent, metadata); } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java index d5cb0242762fd..b98ffbcc1b8e3 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorSupplier.java @@ -49,6 +49,7 @@ Aggregator build( String name, ValuesSourceConfig valuesSourceConfig, int precision, + CardinalityAggregatorFactory.ExecutionMode executionMode, SearchContext context, Aggregator parent, Map metadata diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java index 060e06f7336b3..5ad0f93ca753a 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java @@ -37,6 +37,7 @@ import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.KeywordField; import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.index.DirectoryReader; @@ -66,6 +67,7 @@ import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.aggregations.AggregatorTestCase; import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.MultiBucketConsumerService; import org.opensearch.search.aggregations.pipeline.PipelineAggregator; import org.opensearch.search.aggregations.support.AggregationInspectionHelper; @@ -497,4 +499,139 @@ protected CountingAggregator createCountingAggregator( ) ); } + + private void testAggregationExecutionHint( + AggregationBuilder aggregationBuilder, + Query query, + CheckedConsumer buildIndex, + Consumer verify, + Consumer verifyCollector, + MappedFieldType fieldType + ) throws IOException { + try (Directory directory = newDirectory()) { + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + buildIndex.accept(indexWriter); + indexWriter.close(); + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + + CountingAggregator aggregator = new CountingAggregator( + new AtomicInteger(), + createAggregator(aggregationBuilder, indexSearcher, fieldType) + ); + aggregator.preCollection(); + indexSearcher.search(query, aggregator); + aggregator.postCollection(); + + MultiBucketConsumerService.MultiBucketConsumer reduceBucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + Integer.MAX_VALUE, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + InternalAggregation.ReduceContext context = InternalAggregation.ReduceContext.forFinalReduction( + aggregator.context().bigArrays(), + getMockScriptService(), + reduceBucketConsumer, + PipelineAggregator.PipelineTree.EMPTY + ); + InternalCardinality topLevel = (InternalCardinality) aggregator.buildTopLevel(); + InternalCardinality card = (InternalCardinality) topLevel.reduce(Collections.singletonList(topLevel), context); + doAssertReducedMultiBucketConsumer(card, reduceBucketConsumer); + + verify.accept(card); + verifyCollector.accept(aggregator.getSelectedCollector()); + } + } + } + + public void testInvalidExecutionHint() throws IOException { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("number", NumberFieldMapper.NumberType.LONG); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("number") + .executionHint("invalid"); + assertThrows(IllegalArgumentException.class, () -> testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new NumericDocValuesField("number", 7))); + iw.addDocument(singleton(new NumericDocValuesField("number", 8))); + iw.addDocument(singleton(new NumericDocValuesField("number", 9))); + }, card -> { + assertEquals(3, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.DirectCollector); }, fieldType)); + } + + public void testNoExecutionHintWithNumericDocValues() throws IOException { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("number", NumberFieldMapper.NumberType.LONG); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("number"); + testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new NumericDocValuesField("number", 7))); + iw.addDocument(singleton(new NumericDocValuesField("number", 8))); + iw.addDocument(singleton(new NumericDocValuesField("number", 9))); + }, card -> { + assertEquals(3, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.DirectCollector); }, fieldType); + } + + public void testDirectExecutionHintWithNumericDocValues() throws IOException { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("number", NumberFieldMapper.NumberType.LONG); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("number") + .executionHint("direct"); + testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new NumericDocValuesField("number", 7))); + iw.addDocument(singleton(new NumericDocValuesField("number", 8))); + iw.addDocument(singleton(new NumericDocValuesField("number", 9))); + }, card -> { + assertEquals(3, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.DirectCollector); }, fieldType); + } + + public void testOrdinalsExecutionHintWithNumericDocValues() throws IOException { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType("number", NumberFieldMapper.NumberType.LONG); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("number") + .executionHint("ordinals"); + testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new NumericDocValuesField("number", 7))); + iw.addDocument(singleton(new NumericDocValuesField("number", 8))); + iw.addDocument(singleton(new NumericDocValuesField("number", 9))); + }, card -> { + assertEquals(3, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.DirectCollector); }, fieldType); + } + + public void testNoExecutionHintWithByteValues() throws IOException { + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field"); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("field") + .executionHint("direct"); + testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new SortedDocValuesField("field", new BytesRef()))); + }, card -> { + assertEquals(1, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.OrdinalsCollector); }, fieldType); + } + + public void testDirectExecutionHintWithByteValues() throws IOException { + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field"); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("field") + .executionHint("direct"); + testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new SortedDocValuesField("field", new BytesRef()))); + }, card -> { + assertEquals(1, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.OrdinalsCollector); }, fieldType); + } + + public void testOrdinalsExecutionHintWithByteValues() throws IOException { + MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType("field"); + final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field("field") + .executionHint("ordinals"); + testAggregationExecutionHint(aggregationBuilder, new MatchAllDocsQuery(), iw -> { + iw.addDocument(singleton(new SortedDocValuesField("field", new BytesRef()))); + }, card -> { + assertEquals(1, card.getValue(), 0); + assertTrue(AggregationInspectionHelper.hasValue(card)); + }, collector -> { assertTrue(collector instanceof CardinalityAggregator.OrdinalsCollector); }, fieldType); + } } diff --git a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java index 78e3d4f50a0d5..eba1769ad882d 100644 --- a/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/opensearch/search/aggregations/AggregatorTestCase.java @@ -1331,6 +1331,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { protected static class CountingAggregator extends Aggregator { private final AtomicInteger collectCounter; public final Aggregator delegate; + private LeafBucketCollector selectedCollector; public CountingAggregator(AtomicInteger collectCounter, Aggregator delegate) { this.collectCounter = collectCounter; @@ -1341,6 +1342,10 @@ public AtomicInteger getCollectCount() { return collectCounter; } + public LeafBucketCollector getSelectedCollector() { + return selectedCollector; + } + @Override public void close() { delegate.close(); @@ -1381,7 +1386,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOExce return new LeafBucketCollector() { @Override public void collect(int doc, long bucket) throws IOException { - delegate.getLeafCollector(ctx).collect(doc, bucket); + selectedCollector = delegate.getLeafCollector(ctx); + selectedCollector.collect(doc, bucket); collectCounter.incrementAndGet(); } };