Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Avoid NPE if on SnapshotInfo if 'shallow' boolean not present ([#18187](https://github.com/opensearch-project/OpenSearch/issues/18187))
- Fix 'system call filter not installed' caused when network.host: 0.0.0.0 ([#18309](https://github.com/opensearch-project/OpenSearch/pull/18309))
- Fix MatrixStatsAggregator reuse when mode parameter changes ([#18242](https://github.com/opensearch-project/OpenSearch/issues/18242))
- Add cancellation check in aggregation code paths at shard and coordinator node ([#18386](https://github.com/opensearch-project/OpenSearch/pull/18386))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,15 @@ public class TermsReduceBenchmark {
private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables());
private final SearchPhaseController controller = new SearchPhaseController(
namedWriteableRegistry,
req -> new InternalAggregation.ReduceContextBuilder() {
(req, isTaskCancelled) -> new InternalAggregation.ReduceContextBuilder() {
@Override
public InternalAggregation.ReduceContext forPartialReduction() {
return InternalAggregation.ReduceContext.forPartialReduction(null, null, () -> PipelineAggregator.PipelineTree.EMPTY);
return InternalAggregation.ReduceContext.forPartialReduction(
null,
null,
() -> PipelineAggregator.PipelineTree.EMPTY,
isTaskCancelled
);
}

@Override
Expand All @@ -114,7 +119,8 @@ public InternalAggregation.ReduceContext forFinalReduction() {
null,
null,
bucketConsumer,
PipelineAggregator.PipelineTree.EMPTY
PipelineAggregator.PipelineTree.EMPTY,
isTaskCancelled
);
}
}
Expand Down Expand Up @@ -229,7 +235,8 @@ public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateLis
SearchProgressListener.NOOP,
namedWriteableRegistry,
shards.size(),
exc -> {}
exc -> {},
() -> false
);
CountDownLatch latch = new CountDownLatch(shards.size());
for (int i = 0; i < shards.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ public Object getProperty(List<String> path) {

@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
checkCancelled(reduceContext);
// merge stats across all shards
List<InternalAggregation> aggs = new ArrayList<>(aggregations);
aggs.removeIf(p -> ((InternalMatrixStats) p).stats == null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ public void testReduceRandom() {
bigArrays,
mockScriptService,
b -> {},
PipelineTree.EMPTY
PipelineTree.EMPTY,
() -> false
);
InternalMatrixStats reduced = (InternalMatrixStats) shardResults.get(0).reduce(shardResults, context);
multiPassStats.assertNearlyEqual(reduced.getResults());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,15 @@ public List<BaseGeoGridBucket> getBuckets() {

@Override
public BaseGeoGrid reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
checkCancelled(reduceContext);
LongObjectPagedHashMap<List<BaseGeoGridBucket>> buckets = null;
for (InternalAggregation aggregation : aggregations) {
BaseGeoGrid grid = (BaseGeoGrid) aggregation;
if (buckets == null) {
buckets = new LongObjectPagedHashMap<>(grid.buckets.size(), reduceContext.bigArrays());
}
for (Object obj : grid.buckets) {
checkCancelled(reduceContext);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does checking for cancellation on every bucket add too much overhead? How about sampling it after a specific number of iterations instead?

BaseGeoGridBucket bucket = (BaseGeoGridBucket) obj;
List<BaseGeoGridBucket> existingBuckets = buckets.get(bucket.hashAsLong());
if (existingBuckets == null) {
Expand All @@ -114,6 +116,7 @@ public BaseGeoGrid reduce(List<InternalAggregation> aggregations, ReduceContext
final int size = Math.toIntExact(reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size()));
BucketPriorityQueue<BaseGeoGridBucket> ordered = new BucketPriorityQueue<>(size);
for (LongObjectPagedHashMap.Cursor<List<BaseGeoGridBucket>> cursor : buckets) {
checkCancelled(reduceContext);
List<BaseGeoGridBucket> sameCellBuckets = cursor.value;
ordered.insertWithOverflow(reduceBucket(sameCellBuckets, reduceContext));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ public String getWriteableName() {

@Override
public InternalAggregation reduce(List<InternalAggregation> aggregations, ReduceContext reduceContext) {
checkCancelled(reduceContext);
double top = Double.NEGATIVE_INFINITY;
double bottom = Double.POSITIVE_INFINITY;
double posLeft = Double.POSITIVE_INFINITY;
Expand All @@ -119,6 +120,7 @@ public InternalAggregation reduce(List<InternalAggregation> aggregations, Reduce
double negRight = Double.NEGATIVE_INFINITY;

for (InternalAggregation aggregation : aggregations) {
checkCancelled(reduceContext);
InternalGeoBounds bounds = (InternalGeoBounds) aggregation;

if (bounds.top > top) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,10 @@ private void testCase(
MappedFieldType fieldType = new GeoPointFieldMapper.GeoPointFieldType(FIELD_NAME);

Aggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType);
aggregator.preCollection();
aggregator.preCollection(() -> {});
indexSearcher.search(query, aggregator);
aggregator.postCollection();
verify.accept((BaseGeoGrid<T>) aggregator.buildTopLevel());
aggregator.postCollection(() -> {});
verify.accept((BaseGeoGrid<T>) aggregator.buildTopLevel(() -> {}));

indexReader.close();
directory.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ public void collect(int docId, long owningBucketOrd) throws IOException {
}

@Override
public void postCollection() throws IOException {
public void postCollection(Runnable checkCancelled) throws IOException {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is something that might be overridden by subclasses in plugins. We need to ensure this does not break those.

// Delaying until beforeBuildingBuckets
}

Expand Down Expand Up @@ -177,7 +177,7 @@ public float score() {
}
}
}
super.postCollection(); // Run post collection after collecting the sub-aggs
super.postCollection(() -> {}); // Run post collection after collecting the sub-aggs
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,9 @@ public ScoreMode scoreMode() {
}

@Override
public void preCollection() throws IOException {}
// No-op override: this collector performs no pre-collection setup, so the
// Runnable hook (presumably the cancellation check added by this change —
// TODO confirm against BucketCollector's contract) is deliberately ignored.
public void preCollection(Runnable r) throws IOException {}

@Override
public void postCollection() throws IOException {}
// No-op override: this collector performs no post-collection work, so the
// Runnable hook (presumably the cancellation check added by this change —
// TODO confirm against BucketCollector's contract) is deliberately ignored.
public void postCollection(Runnable r) throws IOException {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;

/**
Expand All @@ -78,6 +79,7 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhas
private final SearchProgressListener progressListener;
private final ReduceContextBuilder aggReduceContextBuilder;
private final NamedWriteableRegistry namedWriteableRegistry;
private final BooleanSupplier isTaskCancelled;

private final int topNSize;
private final boolean hasTopDocs;
Expand All @@ -99,14 +101,15 @@ public QueryPhaseResultConsumer(
SearchProgressListener progressListener,
NamedWriteableRegistry namedWriteableRegistry,
int expectedResultSize,
Consumer<Exception> onPartialMergeFailure
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isTaskCancelled
) {
super(expectedResultSize);
this.executor = executor;
this.circuitBreaker = circuitBreaker;
this.controller = controller;
this.progressListener = progressListener;
this.aggReduceContextBuilder = controller.getReduceContext(request);
this.aggReduceContextBuilder = controller.getReduceContext(request, isTaskCancelled);
this.namedWriteableRegistry = namedWriteableRegistry;
this.topNSize = SearchPhaseController.getTopDocsSize(request);
this.performFinalReduce = request.isFinalReduce();
Expand All @@ -117,6 +120,7 @@ public QueryPhaseResultConsumer(
this.hasAggs = source != null && source.aggregations() != null;
int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize;
this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo());
this.isTaskCancelled = isTaskCancelled;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;

Expand All @@ -92,11 +93,14 @@ public final class SearchPhaseController {
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];

private final NamedWriteableRegistry namedWriteableRegistry;
private final Function<SearchSourceBuilder, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
private final BiFunction<
SearchSourceBuilder,
BooleanSupplier,
InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;

public SearchPhaseController(
NamedWriteableRegistry namedWriteableRegistry,
Function<SearchSourceBuilder, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder
BiFunction<SearchSourceBuilder, BooleanSupplier, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder
) {
this.namedWriteableRegistry = namedWriteableRegistry;
this.requestToAggReduceContextBuilder = requestToAggReduceContextBuilder;
Expand Down Expand Up @@ -744,8 +748,8 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
}
}

InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request) {
return requestToAggReduceContextBuilder.apply(request.source());
InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request, BooleanSupplier isTaskCancelled) {
return requestToAggReduceContextBuilder.apply(request.source(), isTaskCancelled);
}

/**
Expand All @@ -757,7 +761,8 @@ QueryPhaseResultConsumer newSearchPhaseResults(
SearchProgressListener listener,
SearchRequest request,
int numShards,
Consumer<Exception> onPartialMergeFailure
Consumer<Exception> onPartialMergeFailure,
BooleanSupplier isCancelled
) {
return new QueryPhaseResultConsumer(
request,
Expand All @@ -767,7 +772,8 @@ QueryPhaseResultConsumer newSearchPhaseResults(
listener,
namedWriteableRegistry,
numShards,
onPartialMergeFailure
onPartialMergeFailure,
isCancelled
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BooleanSupplier;
import java.util.function.Function;
import java.util.function.LongSupplier;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -514,6 +515,12 @@ private ActionListener<SearchSourceBuilder> buildRewriteListener(
);
}
OriginalIndices localIndices = remoteClusterIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);
BooleanSupplier isRequestCancelled = () -> {
if (task instanceof CancellableTask) {
return ((CancellableTask) task).isCancelled();
}
return false;
};
Comment on lines +518 to +523
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be static lambda somewhere? Doesn't have any logic specific to this class?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is capturing the task instance from the enclosing scope

if (remoteClusterIndices.isEmpty()) {
executeLocalSearch(
task,
Expand All @@ -533,7 +540,7 @@ private ActionListener<SearchSourceBuilder> buildRewriteListener(
localIndices,
remoteClusterIndices,
timeProvider,
searchService.aggReduceContextBuilder(searchRequest.source()),
searchService.aggReduceContextBuilder(searchRequest.source(), isRequestCancelled),
remoteClusterService,
threadPool,
listener,
Expand Down Expand Up @@ -1265,7 +1272,8 @@ private AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction
task.getProgressListener(),
searchRequest,
shardIterators.size(),
exc -> cancelTask(task, exc)
exc -> cancelTask(task, exc),
task::isCancelled
);
AbstractSearchAsyncAction<? extends SearchPhaseResult> searchAsyncAction;
switch (searchRequest.searchType()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.function.BooleanSupplier;
import java.util.function.LongSupplier;

import static org.opensearch.search.SearchService.AGGREGATION_REWRITE_FILTER_SEGMENT_THRESHOLD;
Expand Down Expand Up @@ -206,7 +207,10 @@ final class DefaultSearchContext extends SearchContext {
private final Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
private final QueryShardContext queryShardContext;
private final FetchPhase fetchPhase;
private final Function<SearchSourceBuilder, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
private final BiFunction<
SearchSourceBuilder,
BooleanSupplier,
InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder;
private final String concurrentSearchMode;
private final SetOnce<Boolean> requestShouldUseConcurrentSearch = new SetOnce<>();
private final int maxAggRewriteFilters;
Expand All @@ -227,7 +231,7 @@ final class DefaultSearchContext extends SearchContext {
Version minNodeVersion,
boolean validate,
Executor executor,
Function<SearchSourceBuilder, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder,
BiFunction<SearchSourceBuilder, BooleanSupplier, InternalAggregation.ReduceContextBuilder> requestToAggReduceContextBuilder,
Collection<ConcurrentSearchRequestDecider.Factory> concurrentSearchDeciderFactories
) throws IOException {
this.readerContext = readerContext;
Expand Down Expand Up @@ -1042,7 +1046,8 @@ public ReaderContext readerContext() {

@Override
public InternalAggregation.ReduceContext partialOnShard() {
InternalAggregation.ReduceContext rc = requestToAggReduceContextBuilder.apply(request.source()).forPartialReduction();
InternalAggregation.ReduceContext rc = requestToAggReduceContextBuilder.apply(request.source(), this::isCancelled)
.forPartialReduction();
rc.setSliceLevel(shouldUseConcurrentSearch());
return rc;
}
Expand Down
12 changes: 9 additions & 3 deletions server/src/main/java/org/opensearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BooleanSupplier;
import java.util.function.LongSupplier;

import static org.opensearch.common.unit.TimeValue.timeValueHours;
Expand Down Expand Up @@ -1804,14 +1805,18 @@ public IndicesService getIndicesService() {
* Returns a builder for {@link InternalAggregation.ReduceContext}. This
* builder retains a reference to the provided {@link SearchSourceBuilder}.
*/
public InternalAggregation.ReduceContextBuilder aggReduceContextBuilder(SearchSourceBuilder searchSourceBuilder) {
public InternalAggregation.ReduceContextBuilder aggReduceContextBuilder(
SearchSourceBuilder searchSourceBuilder,
BooleanSupplier isRequestCancelled
) {
return new InternalAggregation.ReduceContextBuilder() {
@Override
public InternalAggregation.ReduceContext forPartialReduction() {
return InternalAggregation.ReduceContext.forPartialReduction(
bigArrays,
scriptService,
() -> requestToPipelineTree(searchSourceBuilder)
() -> requestToPipelineTree(searchSourceBuilder),
isRequestCancelled
);
}

Expand All @@ -1822,7 +1827,8 @@ public ReduceContext forFinalReduction() {
bigArrays,
scriptService,
multiBucketConsumerService.create(),
pipelineTree
pipelineTree,
isRequestCancelled
);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public String getCollectorReason() {

@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
final List<InternalAggregation> internals = context.bucketCollectorProcessor().toInternalAggregations(collectors);
final List<InternalAggregation> internals = context.bucketCollectorProcessor()
.toInternalAggregations(collectors, context::isCancelled);
assert internals.stream().noneMatch(Objects::isNull);
context.aggregations().resetBucketMultiConsumer();

Expand All @@ -70,7 +71,7 @@ protected AggregationReduceableSearchResult buildAggregationResult(InternalAggre

static Collector createCollector(List<Aggregator> collectors) throws IOException {
Collector collector = MultiBucketCollector.wrap(collectors);
((BucketCollector) collector).preCollection();
((BucketCollector) collector).preCollection(() -> {});
return collector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ public interface BucketComparator {
* of another aggregation then the aggregation that contains
* it will call {@link #buildAggregations(long[])}.
*/
public final InternalAggregation buildTopLevel() throws IOException {
/**
 * Builds the root-level aggregation result once collection has completed.
 * May only be invoked on the top-level aggregator (asserted via
 * {@code parent() == null}).
 *
 * @param checkCancelled hook run before any building work; callers are expected
 *                       to pass a Runnable that aborts (e.g. by throwing) when
 *                       the search task has been cancelled — TODO confirm
 * @return the root {@link InternalAggregation}
 * @throws IOException if building the aggregation fails
 */
public final InternalAggregation buildTopLevel(Runnable checkCancelled) throws IOException {
    assert parent() == null;
    // Give the caller a chance to bail out before the potentially expensive
    // bucket-building below.
    checkCancelled.run();
    final InternalAggregation root = buildAggregations(new long[] { 0 })[0];
    this.internalAggregation.set(root);
    return internalAggregation.get();
}
Expand Down
Loading
Loading