From 90678c2d76ed1c31aebc85a23fe5be12f6e81417 Mon Sep 17 00:00:00 2001 From: Sorabh Date: Fri, 16 Jun 2023 18:22:35 -0700 Subject: [PATCH] With only GlobalAggregation in request causes unnecessary wrapping with MultiCollector (#8125) Signed-off-by: Sorabh Hamirwasia --- CHANGELOG.md | 1 + .../aggregation/AggregationProfilerIT.java | 76 ++++++++++++++++++- .../opensearch/search/query/QueryPhase.java | 26 ++++--- 3 files changed, 90 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 76f0069dfa079..b0ccf9ec008b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -95,6 +95,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Replaces ZipInputStream with ZipFile to fix Zip Slip vulnerability ([#7230](https://github.com/opensearch-project/OpenSearch/pull/7230)) - Add missing validation/parsing of SearchBackpressureMode of SearchBackpressureSettings ([#7541](https://github.com/opensearch-project/OpenSearch/pull/7541)) - Fix mapping char_filter when mapping a hashtag ([#7591](https://github.com/opensearch-project/OpenSearch/pull/7591)) +- With only GlobalAggregation in request causes unnecessary wrapping with MultiCollector ([#8125](https://github.com/opensearch-project/OpenSearch/pull/8125)) ### Security diff --git a/server/src/internalClusterTest/java/org/opensearch/search/profile/aggregation/AggregationProfilerIT.java b/server/src/internalClusterTest/java/org/opensearch/search/profile/aggregation/AggregationProfilerIT.java index 0e9e409efae59..0f08c537d74d8 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/profile/aggregation/AggregationProfilerIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/profile/aggregation/AggregationProfilerIT.java @@ -32,14 +32,19 @@ package org.opensearch.search.profile.aggregation; +import org.hamcrest.core.IsNull; import org.opensearch.action.index.IndexRequestBuilder; import org.opensearch.action.search.SearchResponse; import org.opensearch.search.aggregations.Aggregator.SubAggCollectionMode; import org.opensearch.search.aggregations.BucketOrder; +import org.opensearch.search.aggregations.InternalAggregation; +import org.opensearch.search.aggregations.bucket.global.Global; import org.opensearch.search.aggregations.bucket.sampler.DiversifiedOrdinalsSamplerAggregator; import org.opensearch.search.aggregations.bucket.terms.GlobalOrdinalsStringTermsAggregator; +import org.opensearch.search.aggregations.metrics.Stats; import org.opensearch.search.profile.ProfileResult; import org.opensearch.search.profile.ProfileShardResult; +import org.opensearch.search.profile.query.QueryProfileShardResult; import org.opensearch.test.OpenSearchIntegTestCase; import java.util.ArrayList; @@ -48,11 +53,15 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.search.aggregations.AggregationBuilders.avg; import static org.opensearch.search.aggregations.AggregationBuilders.diversifiedSampler; +import static org.opensearch.search.aggregations.AggregationBuilders.global; import static org.opensearch.search.aggregations.AggregationBuilders.histogram; import static org.opensearch.search.aggregations.AggregationBuilders.max; +import static org.opensearch.search.aggregations.AggregationBuilders.stats; import static org.opensearch.search.aggregations.AggregationBuilders.terms; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSearchResponse; @@ -95,6 +104,7 @@ public class AggregationProfilerIT extends OpenSearchIntegTestCase { private static final String NUMBER_FIELD = "number"; private static final String TAG_FIELD = "tag"; private static final String STRING_FIELD = "string_field"; + private final int numDocs = 5; @Override protected int numberOfShards() { @@ -118,7 +128,7 @@ protected void setupSuiteScopeCluster() throws Exception { randomStrings[i] = randomAlphaOfLength(10); } - for (int i = 0; i < 5; i++) { + for (int i = 0; i < numDocs; i++) { builders.add( client().prepareIndex("idx") .setSource( @@ -633,4 +643,68 @@ public void testNoProfile() { assertThat(profileResults, notNullValue()); assertThat(profileResults.size(), equalTo(0)); } + + public void testGlobalAggWithStatsSubAggregatorProfile() { + boolean profileEnabled = true; + SearchResponse response = client().prepareSearch("idx") + .addAggregation(global("global").subAggregation(stats("value_stats").field(NUMBER_FIELD))) + .setProfile(profileEnabled) + .get(); + + assertSearchResponse(response); + + Global global = response.getAggregations().get("global"); + assertThat(global, IsNull.notNullValue()); + assertThat(global.getName(), equalTo("global")); + assertThat(global.getDocCount(), equalTo((long) numDocs)); + assertThat((long) ((InternalAggregation) global).getProperty("_count"), equalTo((long) numDocs)); + assertThat(global.getAggregations().asList().isEmpty(), is(false)); + + Stats stats = global.getAggregations().get("value_stats"); + assertThat((Stats) ((InternalAggregation) global).getProperty("value_stats"), sameInstance(stats)); + assertThat(stats, IsNull.notNullValue()); + assertThat(stats.getName(), equalTo("value_stats")); + + Map profileResults = response.getProfileResults(); + assertThat(profileResults, notNullValue()); + assertThat(profileResults.size(), equalTo(getNumShards("idx").numPrimaries)); + for (ProfileShardResult profileShardResult : profileResults.values()) { + assertThat(profileShardResult, notNullValue()); + List queryProfileShardResults = profileShardResult.getQueryProfileResults(); + assertEquals(queryProfileShardResults.size(), 2); + // ensure there is no multi collector getting added with only global agg + for (QueryProfileShardResult queryProfileShardResult : queryProfileShardResults) { + assertEquals(queryProfileShardResult.getQueryResults().size(), 1); + if (queryProfileShardResult.getQueryResults().get(0).getQueryName().equals("MatchAllDocsQuery")) { + assertEquals(0, queryProfileShardResult.getQueryResults().get(0).getProfiledChildren().size()); + assertEquals("search_top_hits", queryProfileShardResult.getCollectorResult().getReason()); + assertEquals(0, queryProfileShardResult.getCollectorResult().getProfiledChildren().size()); + } else if (queryProfileShardResult.getQueryResults().get(0).getQueryName().equals("ConstantScoreQuery")) { + assertEquals(1, queryProfileShardResult.getQueryResults().get(0).getProfiledChildren().size()); + assertEquals("aggregation_global", queryProfileShardResult.getCollectorResult().getReason()); + assertEquals(0, queryProfileShardResult.getCollectorResult().getProfiledChildren().size()); + } else { + fail("unexpected profile shard result in the response"); + } + } + AggregationProfileShardResult aggProfileResults = profileShardResult.getAggregationProfileResults(); + assertThat(aggProfileResults, notNullValue()); + List aggProfileResultsList = aggProfileResults.getProfileResults(); + assertThat(aggProfileResultsList, notNullValue()); + assertEquals(1, aggProfileResultsList.size()); + ProfileResult globalAggResult = aggProfileResultsList.get(0); + assertThat(globalAggResult, notNullValue()); + assertEquals("GlobalAggregator", globalAggResult.getQueryName()); + assertEquals("global", globalAggResult.getLuceneDescription()); + assertEquals(1, globalAggResult.getProfiledChildren().size()); + assertThat(globalAggResult.getTime(), greaterThan(0L)); + Map breakdown = globalAggResult.getTimeBreakdown(); + assertThat(breakdown, notNullValue()); + assertEquals(BREAKDOWN_KEYS, breakdown.keySet()); + assertThat(breakdown.get(INITIALIZE), greaterThan(0L)); + assertThat(breakdown.get(COLLECT), greaterThan(0L)); + assertThat(breakdown.get(BUILD_AGGREGATION).longValue(), greaterThan(0L)); + assertEquals(0, breakdown.get(REDUCE).intValue()); + } + } } diff --git a/server/src/main/java/org/opensearch/search/query/QueryPhase.java b/server/src/main/java/org/opensearch/search/query/QueryPhase.java index 06be6683b5e4c..069b410280d63 100644 --- a/server/src/main/java/org/opensearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/opensearch/search/query/QueryPhase.java @@ -39,6 +39,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; +import org.apache.lucene.search.CollectorManager; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; @@ -71,6 +72,7 @@ import java.io.IOException; import java.util.LinkedList; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ExecutorService; @@ -234,19 +236,19 @@ static boolean executeInternal(SearchContext searchContext, QueryPhaseSearcher q // this collector can filter documents during the collection hasFilterCollector = true; } - if (searchContext.queryCollectorManagers().isEmpty() == false) { - // plug in additional collectors, like aggregations except global aggregations - collectors.add( - createMultiCollectorContext( - searchContext.queryCollectorManagers() - .entrySet() - .stream() - .filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class))) - .map(Map.Entry::getValue) - .collect(Collectors.toList()) - ) - ); + + // plug in additional collectors, like aggregations except global aggregations + final List> managersExceptGlobalAgg = searchContext + .queryCollectorManagers() + .entrySet() + .stream() + .filter(entry -> !(entry.getKey().equals(GlobalAggCollectorManager.class))) + .map(Map.Entry::getValue) + .collect(Collectors.toList()); + if (managersExceptGlobalAgg.isEmpty() == false) { + collectors.add(createMultiCollectorContext(managersExceptGlobalAgg)); } + if (searchContext.minimumScore() != null) { // apply the minimum score after multi collector so we filter aggs as well collectors.add(createMinScoreCollectorContext(searchContext.minimumScore()));