Skip to content

Commit

Permalink
Query shape for agg & sort
Browse files Browse the repository at this point in the history
Signed-off-by: David Zane <davizane@amazon.com>
  • Loading branch information
dzane17 committed Jul 29, 2024
1 parent c7fb34d commit a40fad5
Show file tree
Hide file tree
Showing 4 changed files with 561 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.insights.core.service.categorizer;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.index.query.AbstractGeometryQueryBuilder;
import org.opensearch.index.query.CommonTermsQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.FieldMaskingSpanQueryBuilder;
import org.opensearch.index.query.FuzzyQueryBuilder;
import org.opensearch.index.query.GeoDistanceQueryBuilder;
import org.opensearch.index.query.GeoPolygonQueryBuilder;
import org.opensearch.index.query.MatchBoolPrefixQueryBuilder;
import org.opensearch.index.query.MatchPhrasePrefixQueryBuilder;
import org.opensearch.index.query.MatchPhraseQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.MultiTermQueryBuilder;
import org.opensearch.index.query.PrefixQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.RegexpQueryBuilder;
import org.opensearch.index.query.SpanNearQueryBuilder;
import org.opensearch.index.query.SpanTermQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.index.query.WildcardQueryBuilder;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.PipelineAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.AutoDateHistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.VariableWidthHistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.missing.MissingAggregationBuilder;
import org.opensearch.search.aggregations.bucket.range.AbstractRangeBuilder;
import org.opensearch.search.aggregations.bucket.range.GeoDistanceAggregationBuilder;
import org.opensearch.search.aggregations.bucket.range.IpRangeAggregationBuilder;
import org.opensearch.search.aggregations.bucket.sampler.DiversifiedAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.RareTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.AvgAggregationBuilder;
import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder;
import org.opensearch.search.aggregations.metrics.ExtendedStatsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.GeoCentroidAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MaxAggregationBuilder;
import org.opensearch.search.aggregations.metrics.MinAggregationBuilder;
import org.opensearch.search.aggregations.metrics.StatsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.SumAggregationBuilder;
import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder;
import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.SortBuilder;

/**
 * Generates a textual "shape" for a search request: the query tree, the
 * aggregation tree and the sort list, each rendered as indented lines.
 * When field display is enabled, every node that has a registered field
 * extractor is followed by {@code " [field]"}.
 */
public class QueryShapeGenerator {
    /** One indentation level of the rendered shape (two spaces, as the name says). */
    static final String TWO_SPACE_INDENT = "  ";
    static final Map<Class<?>, List<Function<Object, String>>> QUERY_FIELD_DATA_MAP = FieldDataMapHelper.getQueryFieldDataMap();
    static final Map<Class<?>, List<Function<Object, String>>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap();
    static final Map<Class<?>, List<Function<Object, String>>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap();

    /**
     * Builds the complete shape of a search source by concatenating the query,
     * aggregation and sort shapes, in that order.
     *
     * @param source     search source to describe; {@code null} yields an empty string
     * @param showFields whether field data is appended to each node; {@code null} is treated as {@code false}
     * @return the combined shape string
     */
    public static String buildShape(SearchSourceBuilder source, Boolean showFields) {
        if (source == null) {
            return "";
        }
        StringBuilder shape = new StringBuilder();
        shape.append(buildQueryShape(source.query(), showFields));
        shape.append(buildAggregationShape(source.aggregations(), showFields));
        shape.append(buildSortShape(source.sorts(), showFields));
        return shape.toString();
    }

    /**
     * Builds the shape of the query tree by visiting it with a {@link QueryShapeVisitor}.
     *
     * @param queryBuilder top-level query; {@code null} yields an empty string
     * @param showFields   whether field data is appended; {@code null} is treated as {@code false}
     * @return the pretty-printed query tree
     */
    static String buildQueryShape(QueryBuilder queryBuilder, Boolean showFields) {
        if (queryBuilder == null) {
            return "";
        }
        QueryShapeVisitor shapeVisitor = new QueryShapeVisitor();
        queryBuilder.visit(shapeVisitor);
        // Normalise so the visitor never unboxes a null Boolean.
        return shapeVisitor.prettyPrintTree("", Boolean.TRUE.equals(showFields));
    }

    /**
     * Builds the shape of the aggregation tree, starting at indentation level zero.
     *
     * @param aggregationsBuilder aggregations of the request; {@code null} yields an empty string
     * @param showFields          whether field data is appended; {@code null} is treated as {@code false}
     * @return the aggregation shape
     */
    static String buildAggregationShape(AggregatorFactories.Builder aggregationsBuilder, Boolean showFields) {
        if (aggregationsBuilder == null) {
            return "";
        }
        StringBuilder aggregationShape = recursiveAggregationShapeBuilder(
            aggregationsBuilder.getAggregatorFactories(),
            aggregationsBuilder.getPipelineAggregatorFactories(),
            new StringBuilder(),
            0,
            showFields
        );
        return aggregationShape.toString();
    }

    /**
     * Appends the shape of one aggregation level (and, recursively, everything below it)
     * to {@code outputBuilder}. Regular aggregations are listed first, then pipeline
     * aggregations; each list is sorted by its rendered text.
     *
     * @param aggregationBuilders  regular aggregations at this level
     * @param pipelineAggregations pipeline aggregations at this level
     * @param outputBuilder        builder receiving the rendered lines
     * @param indentCount          number of {@link #TWO_SPACE_INDENT} units prefixed to this level
     * @param showFields           whether field data is appended; {@code null} is treated as {@code false}
     * @return {@code outputBuilder}, for chaining
     */
    static StringBuilder recursiveAggregationShapeBuilder(
        Collection<AggregationBuilder> aggregationBuilders,
        Collection<PipelineAggregationBuilder> pipelineAggregations,
        StringBuilder outputBuilder,
        int indentCount,
        Boolean showFields
    ) {
        final boolean includeFields = Boolean.TRUE.equals(showFields);
        final String baseIndent = TWO_SPACE_INDENT.repeat(indentCount);

        //// Normal Aggregations ////
        if (aggregationBuilders.isEmpty() == false) {
            outputBuilder.append(baseIndent).append("aggregation:").append("\n");
        }
        List<String> aggShapeStrings = new ArrayList<>();
        for (AggregationBuilder aggBuilder : aggregationBuilders) {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.append(baseIndent).append(TWO_SPACE_INDENT).append(aggBuilder.getType());
            if (includeFields) {
                stringBuilder.append(
                    buildFieldDataString(findFieldDataMethods(AGG_FIELD_DATA_MAP, aggBuilder.getClass()), aggBuilder)
                );
            }
            stringBuilder.append("\n");

            // Recurse when there is anything nested. Checking pipeline aggregations too keeps
            // pipeline-only children from being silently dropped from the shape.
            boolean hasNested = aggBuilder.getSubAggregations().isEmpty() == false
                || aggBuilder.getPipelineAggregations().isEmpty() == false;
            if (hasNested) {
                // The nested output goes into this aggregation's own builder so that the
                // whole subtree moves as one unit when the siblings are sorted below.
                recursiveAggregationShapeBuilder(
                    aggBuilder.getSubAggregations(),
                    aggBuilder.getPipelineAggregations(),
                    stringBuilder,
                    indentCount + 2,
                    showFields
                );
            }
            aggShapeStrings.add(stringBuilder.toString());
        }

        // Sort alphanumerically and append aggregations list
        Collections.sort(aggShapeStrings);
        for (String shapeString : aggShapeStrings) {
            outputBuilder.append(shapeString);
        }

        //// Pipeline Aggregation (cannot have sub-aggregations) ////
        if (pipelineAggregations.isEmpty() == false) {
            outputBuilder.append(baseIndent).append(TWO_SPACE_INDENT).append("pipeline aggregation:").append("\n");

            List<String> pipelineAggShapeStrings = new ArrayList<>();
            for (PipelineAggregationBuilder pipelineAgg : pipelineAggregations) {
                pipelineAggShapeStrings.add(baseIndent + TWO_SPACE_INDENT.repeat(2) + pipelineAgg.getType() + "\n");
            }

            // Sort alphanumerically and append pipeline aggregations list
            Collections.sort(pipelineAggShapeStrings);
            for (String shapeString : pipelineAggShapeStrings) {
                outputBuilder.append(shapeString);
            }
        }

        return outputBuilder;
    }

    /**
     * Builds the shape of the sort list: a {@code "sort:"} header followed by one
     * indented line per sort (its order, plus field data when enabled), sorted by text.
     *
     * @param sortBuilderList sorts of the request; {@code null} or empty yields an empty string
     * @param showFields      whether field data is appended; {@code null} is treated as {@code false}
     * @return the sort shape
     */
    static String buildSortShape(List<SortBuilder<?>> sortBuilderList, Boolean showFields) {
        if (sortBuilderList == null || sortBuilderList.isEmpty()) {
            return "";
        }
        final boolean includeFields = Boolean.TRUE.equals(showFields);
        StringBuilder sortShape = new StringBuilder();
        sortShape.append("sort:\n");

        List<String> shapeStrings = new ArrayList<>();
        for (SortBuilder<?> sortBuilder : sortBuilderList) {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.append(TWO_SPACE_INDENT).append(sortBuilder.order());
            if (includeFields) {
                stringBuilder.append(
                    buildFieldDataString(findFieldDataMethods(SORT_FIELD_DATA_MAP, sortBuilder.getClass()), sortBuilder)
                );
            }
            shapeStrings.add(stringBuilder.toString());
        }

        Collections.sort(shapeStrings);
        for (String line : shapeStrings) {
            sortShape.append(line).append("\n");
        }
        return sortShape.toString();
    }

    /**
     * Applies every extractor in {@code methods} to {@code builder} and renders the
     * results as {@code " [a, b]"}. With no extractors the result is {@code " []"}.
     *
     * @param methods extractors to apply; may be {@code null}
     * @param builder object the extractors are applied to
     * @return the bracketed, comma-separated field data, prefixed with a space
     */
    static String buildFieldDataString(List<Function<Object, String>> methods, NamedWriteable builder) {
        List<String> fieldDataList = new ArrayList<>();
        if (methods != null) {
            for (Function<Object, String> lambda : methods) {
                fieldDataList.add(lambda.apply(builder));
            }
        }
        return " [" + String.join(", ", fieldDataList) + "]";
    }

    /**
     * Finds the extractors registered for {@code type} or, failing that, for its nearest
     * superclass. A plain {@code map.get(obj.getClass())} can never hit entries keyed by
     * abstract base classes, because a runtime class is always concrete.
     *
     * @param fieldDataMap map from builder class to extractors
     * @param type         runtime class of the builder
     * @return the extractors of the most specific registered class, or {@code null} if none
     */
    private static List<Function<Object, String>> findFieldDataMethods(
        Map<Class<?>, List<Function<Object, String>>> fieldDataMap,
        Class<?> type
    ) {
        for (Class<?> current = type; current != null; current = current.getSuperclass()) {
            List<Function<Object, String>> methods = fieldDataMap.get(current);
            if (methods != null) {
                return methods;
            }
        }
        return null;
    }

    /**
     * Creates the maps that associate query, aggregation and sort builder classes
     * with the functions used to read their field names.
     */
    public static class FieldDataMapHelper {

        /**
         * Creates a map entry for {@code clazz} whose single extractor casts its
         * argument to {@code clazz} before delegating to {@code extractor}.
         */
        private static <T> Map.Entry<Class<?>, List<Function<Object, String>>> createEntry(Class<T> clazz, Function<T, String> extractor) {
            Function<Object, String> untyped = obj -> extractor.apply(clazz.cast(obj));
            return Map.entry(clazz, List.of(untyped));
        }

        /**
         * @return an immutable map from query builder class to its field-name extractors
         */
        public static Map<Class<?>, List<Function<Object, String>>> getQueryFieldDataMap() {
            return Map.ofEntries(
                createEntry(AbstractGeometryQueryBuilder.class, AbstractGeometryQueryBuilder::fieldName),
                createEntry(CommonTermsQueryBuilder.class, CommonTermsQueryBuilder::fieldName),
                createEntry(ExistsQueryBuilder.class, ExistsQueryBuilder::fieldName),
                createEntry(FieldMaskingSpanQueryBuilder.class, FieldMaskingSpanQueryBuilder::fieldName),
                createEntry(FuzzyQueryBuilder.class, FuzzyQueryBuilder::fieldName),
                createEntry(
                    org.opensearch.index.query.GeoBoundingBoxQueryBuilder.class,
                    org.opensearch.index.query.GeoBoundingBoxQueryBuilder::fieldName
                ),
                createEntry(GeoDistanceQueryBuilder.class, GeoDistanceQueryBuilder::fieldName),
                createEntry(GeoPolygonQueryBuilder.class, GeoPolygonQueryBuilder::fieldName),
                createEntry(MatchBoolPrefixQueryBuilder.class, MatchBoolPrefixQueryBuilder::fieldName),
                createEntry(MatchQueryBuilder.class, MatchQueryBuilder::fieldName),
                createEntry(MatchPhraseQueryBuilder.class, MatchPhraseQueryBuilder::fieldName),
                createEntry(MatchPhrasePrefixQueryBuilder.class, MatchPhrasePrefixQueryBuilder::fieldName),
                createEntry(MultiTermQueryBuilder.class, MultiTermQueryBuilder::fieldName),
                createEntry(PrefixQueryBuilder.class, PrefixQueryBuilder::fieldName),
                createEntry(RangeQueryBuilder.class, RangeQueryBuilder::fieldName),
                createEntry(RegexpQueryBuilder.class, RegexpQueryBuilder::fieldName),
                createEntry(SpanNearQueryBuilder.SpanGapQueryBuilder.class, SpanNearQueryBuilder.SpanGapQueryBuilder::fieldName),
                createEntry(SpanTermQueryBuilder.class, SpanTermQueryBuilder::fieldName),
                createEntry(TermQueryBuilder.class, TermQueryBuilder::fieldName),
                createEntry(TermsQueryBuilder.class, TermsQueryBuilder::fieldName),
                createEntry(WildcardQueryBuilder.class, WildcardQueryBuilder::fieldName)
            );
        }

        /**
         * @return an immutable map from aggregation builder class to its field extractors
         */
        public static Map<Class<?>, List<Function<Object, String>>> getAggFieldDataMap() {
            return Map.ofEntries(
                createEntry(IpRangeAggregationBuilder.class, IpRangeAggregationBuilder::field),
                createEntry(AutoDateHistogramAggregationBuilder.class, AutoDateHistogramAggregationBuilder::field),
                createEntry(DateHistogramAggregationBuilder.class, DateHistogramAggregationBuilder::field),
                createEntry(HistogramAggregationBuilder.class, HistogramAggregationBuilder::field),
                createEntry(VariableWidthHistogramAggregationBuilder.class, VariableWidthHistogramAggregationBuilder::field),
                createEntry(MissingAggregationBuilder.class, MissingAggregationBuilder::field),
                createEntry(AbstractRangeBuilder.class, AbstractRangeBuilder::field),
                createEntry(GeoDistanceAggregationBuilder.class, GeoDistanceAggregationBuilder::field),
                createEntry(DiversifiedAggregationBuilder.class, DiversifiedAggregationBuilder::field),
                createEntry(RareTermsAggregationBuilder.class, RareTermsAggregationBuilder::field),
                createEntry(SignificantTermsAggregationBuilder.class, SignificantTermsAggregationBuilder::field),
                createEntry(TermsAggregationBuilder.class, TermsAggregationBuilder::field),
                createEntry(AvgAggregationBuilder.class, AvgAggregationBuilder::field),
                createEntry(CardinalityAggregationBuilder.class, CardinalityAggregationBuilder::field),
                createEntry(ExtendedStatsAggregationBuilder.class, ExtendedStatsAggregationBuilder::field),
                createEntry(GeoCentroidAggregationBuilder.class, GeoCentroidAggregationBuilder::field),
                createEntry(MaxAggregationBuilder.class, MaxAggregationBuilder::field),
                createEntry(MinAggregationBuilder.class, MinAggregationBuilder::field),
                createEntry(StatsAggregationBuilder.class, StatsAggregationBuilder::field),
                createEntry(SumAggregationBuilder.class, SumAggregationBuilder::field),
                createEntry(ValueCountAggregationBuilder.class, ValueCountAggregationBuilder::field),
                createEntry(ValuesSourceAggregationBuilder.class, ValuesSourceAggregationBuilder::field)
            );
        }

        /**
         * @return an immutable map from sort builder class to its field-name extractors
         */
        public static Map<Class<?>, List<Function<Object, String>>> getSortFieldDataMap() {
            return Map.ofEntries(createEntry(FieldSortBuilder.class, FieldSortBuilder::getFieldName));
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

package org.opensearch.plugin.insights.core.service.categorizer;

import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.QUERY_FIELD_DATA_MAP;
import static org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator.TWO_SPACE_INDENT;

import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import org.apache.lucene.search.BooleanClause;
import org.opensearch.common.SetOnce;
import org.opensearch.index.query.QueryBuilder;
Expand All @@ -23,11 +27,21 @@
*/
public final class QueryShapeVisitor implements QueryBuilderVisitor {
private final SetOnce<String> queryType = new SetOnce<>();
private final SetOnce<String> fieldData = new SetOnce<>();
private final Map<BooleanClause.Occur, List<QueryShapeVisitor>> childVisitors = new EnumMap<>(BooleanClause.Occur.class);

@Override
public void accept(QueryBuilder qb) {
queryType.set(qb.getName());
public void accept(QueryBuilder queryBuilder) {
    // Records the query's name and the field data produced by its registered extractors.
    queryType.set(queryBuilder.getName());

    // Look up extractors for the runtime class first, then for each superclass in turn.
    // An exact-class lookup alone can never match map entries keyed by abstract base
    // classes, since getClass() always returns a concrete type.
    List<Function<Object, String>> methods = null;
    for (Class<?> type = queryBuilder.getClass(); type != null && methods == null; type = type.getSuperclass()) {
        methods = QUERY_FIELD_DATA_MAP.get(type);
    }

    List<String> fieldDataList = new ArrayList<>();
    if (methods != null) {
        for (Function<Object, String> lambda : methods) {
            fieldDataList.add(lambda.apply(queryBuilder));
        }
    }
    // Stored as a comma-separated string; empty when no extractor is registered.
    fieldData.set(String.join(", ", fieldDataList));
}

@Override
Expand Down Expand Up @@ -83,12 +97,16 @@ public String toJson() {
* @param indent indent size
* @return Query builder tree as a pretty string
*/
public String prettyPrintTree(String indent) {
StringBuilder outputBuilder = new StringBuilder(indent).append(queryType.get()).append("\n");
public String prettyPrintTree(String indent, Boolean showFields) {
StringBuilder outputBuilder = new StringBuilder(indent).append(queryType.get());
if (showFields) {
outputBuilder.append(" [").append(fieldData.get()).append("]");
}
outputBuilder.append("\n");
for (Map.Entry<BooleanClause.Occur, List<QueryShapeVisitor>> entry : childVisitors.entrySet()) {
outputBuilder.append(indent).append(" ").append(entry.getKey().name().toLowerCase(Locale.ROOT)).append(":\n");
outputBuilder.append(indent).append(TWO_SPACE_INDENT).append(entry.getKey().name().toLowerCase(Locale.ROOT)).append(":\n");
for (QueryShapeVisitor child : entry.getValue()) {
outputBuilder.append(child.prettyPrintTree(indent + " "));
outputBuilder.append(child.prettyPrintTree(indent + TWO_SPACE_INDENT.repeat(2), showFields));
}
}
return outputBuilder.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*/
public final class SearchQueryCategorizer {

private static final Logger log = LogManager.getLogger(SearchQueryCategorizer.class);
private static final Logger logger = LogManager.getLogger(SearchQueryCategorizer.class);

/**
* Contains all the search query counters
Expand Down Expand Up @@ -83,11 +83,14 @@ public void categorize(SearchQueryRecord record) {
SearchSourceBuilder source = (SearchSourceBuilder) record.getAttributes().get(Attribute.SOURCE);
Map<MetricType, Number> measurements = record.getMeasurements();

QueryBuilder topLevelQueryBuilder = source.query();
logQueryShape(topLevelQueryBuilder);
incrementQueryTypeCounters(topLevelQueryBuilder, measurements);
incrementQueryTypeCounters(source.query(), measurements);
incrementQueryAggregationCounters(source.aggregations(), measurements);
incrementQuerySortCounters(source.sorts(), measurements);

if (logger.isTraceEnabled()) {
String searchShape = QueryShapeGenerator.buildShape(source, true);
logger.trace(searchShape);
}
}

private void incrementQuerySortCounters(List<SortBuilder<?>> sorts, Map<MetricType, Number> measurements) {
Expand Down Expand Up @@ -115,17 +118,6 @@ private void incrementQueryTypeCounters(QueryBuilder topLevelQueryBuilder, Map<M
topLevelQueryBuilder.visit(searchQueryVisitor);
}

/**
 * Writes the pretty-printed shape of the given query to the trace log.
 * Does nothing when trace logging is disabled or the query is {@code null}.
 *
 * @param topLevelQueryBuilder root of the query tree, may be {@code null}
 */
private void logQueryShape(QueryBuilder topLevelQueryBuilder) {
    // Guard clause: skip the tree walk entirely unless it will actually be logged.
    if (log.isTraceEnabled() == false || topLevelQueryBuilder == null) {
        return;
    }
    QueryShapeVisitor visitor = new QueryShapeVisitor();
    topLevelQueryBuilder.visit(visitor);
    String prettyShape = visitor.prettyPrintTree(" ");
    log.trace("Query shape : {}", prettyShape);
}

/**
* Get search query counters
* @return search query counters
Expand Down
Loading

0 comments on commit a40fad5

Please sign in to comment.