Skip to content

Commit

Permalink
Add query shape hash code
Browse files Browse the repository at this point in the history
Signed-off-by: David Zane <davizane@amazon.com>
  • Loading branch information
dzane17 committed Aug 1, 2024
1 parent b55d760 commit d1b6c97
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ public class QueryShapeGenerator {
static final Map<Class<?>, List<Function<Object, String>>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap();
static final Map<Class<?>, List<Function<Object, String>>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap();

/**
 * Computes a hash code identifying the query shape of a search request source.
 * <p>
 * The shape string is produced by {@code buildShape} with the same arguments and is then
 * hashed with {@link String#hashCode()}.
 *
 * @param source search request source
 * @param showFields whether to include field data in query shape
 * @return hash code of the query shape string, as a primitive {@code int}
 */
public static int getShapeHashCode(SearchSourceBuilder source, Boolean showFields) {
    return buildShape(source, showFields).hashCode();
}

/**
* Method to build search query shape given a source
* @param source search request source
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package org.opensearch.plugin.insights;

import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.RegexpQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;

/**
 * Static factory methods producing pre-populated {@link SearchSourceBuilder} fixtures
 * (query only, aggregations only, sorts only, or all three combined).
 * <p>
 * The query, aggregation and sort sections are each defined once in a private helper so the
 * combined fixture cannot drift from the single-section fixtures.
 */
public final class SearchSourceBuilderUtils {

    /** Static-only utility; not meant to be instantiated. */
    private SearchSourceBuilderUtils() {}

    /**
     * Builds a source with size 0 containing the bool query, all aggregations and all sorts.
     *
     * @return a new builder with query, aggregation and sort sections populated
     */
    public static SearchSourceBuilder createDefaultSearchSourceBuilder() {
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        sourceBuilder.size(0);
        addBoolQuery(sourceBuilder);
        addAggregations(sourceBuilder);
        addSorts(sourceBuilder);
        return sourceBuilder;
    }

    /**
     * Builds a source with size 0 containing only the bool query.
     *
     * @return a new builder with only the query section populated
     */
    public static SearchSourceBuilder createQuerySearchSourceBuilder() {
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        sourceBuilder.size(0);
        addBoolQuery(sourceBuilder);
        return sourceBuilder;
    }

    /**
     * Builds a source containing only the aggregations (size is left at its default).
     *
     * @return a new builder with only the aggregation section populated
     */
    public static SearchSourceBuilder createAggregationSearchSourceBuilder() {
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        addAggregations(sourceBuilder);
        return sourceBuilder;
    }

    /**
     * Builds a source containing only the sorts (size is left at its default).
     *
     * @return a new builder with only the sort section populated
     */
    public static SearchSourceBuilder createSortSearchSourceBuilder() {
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
        addSorts(sourceBuilder);
        return sourceBuilder;
    }

    /** Sets a bool query with one must (term), two filter (match, range) and one should (regexp) clause. */
    private static void addBoolQuery(SearchSourceBuilder sourceBuilder) {
        TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
        MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
        RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
        RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
        sourceBuilder.query(
            new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
        );
    }

    /** Adds the terms, top-hits, significant-text and pipeline aggregations, in a fixed insertion order. */
    private static void addAggregations(SearchSourceBuilder sourceBuilder) {
        sourceBuilder.aggregation(
            new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
                .field("type")
                .subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
                .subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
        );
        sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
        sourceBuilder.aggregation(
            new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
                .field("key")
                .subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
                .subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
                .subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
        );
        sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
        sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
        sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
        sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
        sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));
    }

    /** Adds two descending sorts (color, vendor) followed by two ascending sorts (price, album). */
    private static void addSorts(SearchSourceBuilder sourceBuilder) {
        sourceBuilder.sort("color", SortOrder.DESC);
        sourceBuilder.sort("vendor", SortOrder.DESC);
        sourceBuilder.sort("price", SortOrder.ASC);
        sourceBuilder.sort("album", SortOrder.ASC);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,15 @@

package org.opensearch.plugin.insights.core.service.categorizor;

import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.RegexpQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.plugin.insights.SearchSourceBuilderUtils;
import org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator;
import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.support.ValueType;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.test.OpenSearchTestCase;

public final class QueryShapeGeneratorTests extends OpenSearchTestCase {

public void testComplexSearch() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.size(0);
// build query
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
sourceBuilder.query(
new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
);
// build agg
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
.field("type")
.subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
);
sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
.field("key")
.subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
.subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
.subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
);
sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));
// build sort
sourceBuilder.sort("color", SortOrder.DESC);
sourceBuilder.sort("vendor", SortOrder.DESC);
sourceBuilder.sort("price", SortOrder.ASC);
sourceBuilder.sort("album", SortOrder.ASC);
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createDefaultSearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "bool []\n"
Expand Down Expand Up @@ -136,15 +90,7 @@ public void testComplexSearch() {
}

public void testQueryShape() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.size(0);
TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2");
MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php");
RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text");
RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4");
sourceBuilder.query(
new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder)
);
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createQuerySearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "bool []\n"
Expand All @@ -170,26 +116,7 @@ public void testQueryShape() {
}

public void testAggregationShape() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING)
.field("type")
.subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1"))
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3"))
);
sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model"));
sourceBuilder.aggregation(
new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING)
.field("key")
.subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2"))
.subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1"))
.subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2"))
);
sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_"));
sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true));
sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4"));
sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3"));
sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5"));
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createAggregationSearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "aggregation:\n"
Expand Down Expand Up @@ -237,11 +164,7 @@ public void testAggregationShape() {
}

public void testSortShape() {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.sort("color", SortOrder.DESC);
sourceBuilder.sort("vendor", SortOrder.DESC);
sourceBuilder.sort("price", SortOrder.ASC);
sourceBuilder.sort("album", SortOrder.ASC);
SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createSortSearchSourceBuilder();

String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true);
String expectedShowFieldsTrue = "sort:\n" + " asc [album]\n" + " asc [price]\n" + " desc [color]\n" + " desc [vendor]\n";
Expand All @@ -251,4 +174,28 @@ public void testSortShape() {
String expectedShowFieldsFalse = "sort:\n" + " asc\n" + " asc\n" + " desc\n" + " desc\n";
assertEquals(expectedShowFieldsFalse, shapeShowFieldsFalse);
}

/**
 * Checks that the shape hash code is repeatable for the same source, differs between the
 * combined fixture and the query-only fixture, and differs when field data is toggled.
 */
public void testHashCode() {
    // Fixtures: the combined query/aggregation/sort source and the query-only source
    final SearchSourceBuilder fullSource = SearchSourceBuilderUtils.createDefaultSearchSourceBuilder();
    final SearchSourceBuilder queryOnlySource = SearchSourceBuilderUtils.createQuerySearchSourceBuilder();

    // Field data included in the shape
    final int fullWithFields = QueryShapeGenerator.getShapeHashCode(fullSource, true);
    final int queryOnlyWithFields = QueryShapeGenerator.getShapeHashCode(queryOnlySource, true);
    // Recomputing on the same source must give the same value
    assertEquals(fullWithFields, QueryShapeGenerator.getShapeHashCode(fullSource, true));
    assertEquals(queryOnlyWithFields, QueryShapeGenerator.getShapeHashCode(queryOnlySource, true));
    // Different sources must not hash to the same value
    assertNotEquals(fullWithFields, queryOnlyWithFields);

    // Field data excluded from the shape
    final int fullWithoutFields = QueryShapeGenerator.getShapeHashCode(fullSource, false);
    final int queryOnlyWithoutFields = QueryShapeGenerator.getShapeHashCode(queryOnlySource, false);
    assertEquals(fullWithoutFields, QueryShapeGenerator.getShapeHashCode(fullSource, false));
    assertEquals(queryOnlyWithoutFields, QueryShapeGenerator.getShapeHashCode(queryOnlySource, false));
    assertNotEquals(fullWithoutFields, queryOnlyWithoutFields);

    // Toggling field data must change the hash for the same source
    assertNotEquals(fullWithFields, fullWithoutFields);
    assertNotEquals(queryOnlyWithFields, queryOnlyWithoutFields);
}
}

0 comments on commit d1b6c97

Please sign in to comment.