From d1b6c97014e6e80c35b023ce097d627631663dcf Mon Sep 17 00:00:00 2001 From: David Zane Date: Thu, 1 Aug 2024 15:50:06 -0700 Subject: [PATCH] Add query shape hash code Signed-off-by: David Zane --- .../categorizer/QueryShapeGenerator.java | 11 ++ .../insights/SearchSourceBuilderUtils.java | 108 +++++++++++++++++ .../categorizor/QueryShapeGeneratorTests.java | 113 +++++------------- 3 files changed, 149 insertions(+), 83 deletions(-) create mode 100644 src/test/java/org/opensearch/plugin/insights/SearchSourceBuilderUtils.java diff --git a/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java b/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java index b5eff30..f2e2868 100644 --- a/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java +++ b/src/main/java/org/opensearch/plugin/insights/core/service/categorizer/QueryShapeGenerator.java @@ -76,6 +76,17 @@ public class QueryShapeGenerator { static final Map, List>> AGG_FIELD_DATA_MAP = FieldDataMapHelper.getAggFieldDataMap(); static final Map, List>> SORT_FIELD_DATA_MAP = FieldDataMapHelper.getSortFieldDataMap(); + /** + * Method to get the hash code of the query shape string built from a source + * @param source search request source, passed unchanged to buildShape + * @param showFields whether to include field data in query shape, passed unchanged to buildShape + * @return primitive int hash code of the shape string, i.e. {@code buildShape(source, showFields).hashCode()} + */ + public static int getShapeHashCode(SearchSourceBuilder source, Boolean showFields) { + String shape = buildShape(source, showFields); + return shape.hashCode(); + } + /** * Method to build search query shape given a source * @param source search request source diff --git a/src/test/java/org/opensearch/plugin/insights/SearchSourceBuilderUtils.java b/src/test/java/org/opensearch/plugin/insights/SearchSourceBuilderUtils.java new file mode 100644 index 0000000..747e212 --- /dev/null +++ b/src/test/java/org/opensearch/plugin/insights/SearchSourceBuilderUtils.java @@ -0,0 
+1,108 @@ +package org.opensearch.plugin.insights; + +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; +import org.opensearch.index.query.RegexpQueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder; +import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder; +import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder; +import org.opensearch.search.aggregations.support.ValueType; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; + +public class SearchSourceBuilderUtils { + + public static SearchSourceBuilder createDefaultSearchSourceBuilder() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.size(0); + // build query: a bool query with must(term), filter(match), should(regexp) and filter(range) clauses + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2"); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php"); + RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text"); + RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4"); + sourceBuilder.query( + new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder) + ); + // build aggregations: terms (some with sub-aggregations), top-hits, significant-text and pipeline aggregations + sourceBuilder.aggregation( + new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING) + .field("type") + .subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1")) + 
.subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3")) + ); + sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model")); + sourceBuilder.aggregation( + new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING) + .field("key") + .subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2")) + .subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1")) + .subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2")) + ); + sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_")); + sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true)); + sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4")); + sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3")); + sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5")); + // build sort: color and vendor descending, price and album ascending + sourceBuilder.sort("color", SortOrder.DESC); + sourceBuilder.sort("vendor", SortOrder.DESC); + sourceBuilder.sort("price", SortOrder.ASC); + sourceBuilder.sort("album", SortOrder.ASC); + + return sourceBuilder; + } + + public static SearchSourceBuilder createQuerySearchSourceBuilder() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.size(0); + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2"); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php"); + RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text"); + RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4"); + sourceBuilder.query( + new 
BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder) + ); + return sourceBuilder; + } + + public static SearchSourceBuilder createAggregationSearchSourceBuilder() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + + sourceBuilder.aggregation( + new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING) + .field("type") + .subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1")) + .subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3")) + ); + sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model")); + sourceBuilder.aggregation( + new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING) + .field("key") + .subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2")) + .subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1")) + .subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2")) + ); + sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_")); + sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true)); + sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4")); + sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3")); + sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5")); + + return sourceBuilder; + } + + public static SearchSourceBuilder createSortSearchSourceBuilder() { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.sort("color", SortOrder.DESC); + sourceBuilder.sort("vendor", SortOrder.DESC); + sourceBuilder.sort("price", SortOrder.ASC); + 
sourceBuilder.sort("album", SortOrder.ASC); + return sourceBuilder; + } +} diff --git a/src/test/java/org/opensearch/plugin/insights/core/service/categorizor/QueryShapeGeneratorTests.java b/src/test/java/org/opensearch/plugin/insights/core/service/categorizor/QueryShapeGeneratorTests.java index a0320a1..f0903bb 100644 --- a/src/test/java/org/opensearch/plugin/insights/core/service/categorizor/QueryShapeGeneratorTests.java +++ b/src/test/java/org/opensearch/plugin/insights/core/service/categorizor/QueryShapeGeneratorTests.java @@ -8,61 +8,15 @@ package org.opensearch.plugin.insights.core.service.categorizor; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.RangeQueryBuilder; -import org.opensearch.index.query.RegexpQueryBuilder; -import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.plugin.insights.SearchSourceBuilderUtils; import org.opensearch.plugin.insights.core.service.categorizer.QueryShapeGenerator; -import org.opensearch.search.aggregations.bucket.terms.SignificantTextAggregationBuilder; -import org.opensearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder; -import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder; -import org.opensearch.search.aggregations.pipeline.DerivativePipelineAggregationBuilder; -import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder; -import org.opensearch.search.aggregations.support.ValueType; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.sort.SortOrder; import org.opensearch.test.OpenSearchTestCase; public final class QueryShapeGeneratorTests extends OpenSearchTestCase { + public void testComplexSearch() { - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - 
sourceBuilder.size(0); - // build query - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2"); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php"); - RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text"); - RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4"); - sourceBuilder.query( - new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder) - ); - // build agg - sourceBuilder.aggregation( - new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING) - .field("type") - .subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1")) - .subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3")) - ); - sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model")); - sourceBuilder.aggregation( - new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING) - .field("key") - .subAggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2")) - .subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1")) - .subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2")) - ); - sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_")); - sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true)); - sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4")); - sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3")); - sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5")); - // build sort - sourceBuilder.sort("color", SortOrder.DESC); - 
sourceBuilder.sort("vendor", SortOrder.DESC); - sourceBuilder.sort("price", SortOrder.ASC); - sourceBuilder.sort("album", SortOrder.ASC); + SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createDefaultSearchSourceBuilder(); String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true); String expectedShowFieldsTrue = "bool []\n" @@ -136,15 +90,7 @@ public void testComplexSearch() { } public void testQueryShape() { - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.size(0); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("field1", "value2"); - MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery("field2", "php"); - RegexpQueryBuilder regexpQueryBuilder = new RegexpQueryBuilder("field3", "text"); - RangeQueryBuilder rangeQueryBuilder = new RangeQueryBuilder("field4"); - sourceBuilder.query( - new BoolQueryBuilder().must(termQueryBuilder).filter(matchQueryBuilder).should(regexpQueryBuilder).filter(rangeQueryBuilder) - ); + SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createQuerySearchSourceBuilder(); String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true); String expectedShowFieldsTrue = "bool []\n" @@ -170,26 +116,7 @@ public void testQueryShape() { } public void testAggregationShape() { - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.aggregation( - new TermsAggregationBuilder("agg1").userValueTypeHint(ValueType.STRING) - .field("type") - .subAggregation(new DerivativePipelineAggregationBuilder("pipeline-agg1", "bucket1")) - .subAggregation(new TermsAggregationBuilder("child-agg3").userValueTypeHint(ValueType.STRING).field("key.sub3")) - ); - sourceBuilder.aggregation(new TermsAggregationBuilder("agg2").userValueTypeHint(ValueType.STRING).field("model")); - sourceBuilder.aggregation( - new TermsAggregationBuilder("agg3").userValueTypeHint(ValueType.STRING) - .field("key") - .subAggregation(new 
MaxBucketPipelineAggregationBuilder("pipeline-agg2", "bucket2")) - .subAggregation(new TermsAggregationBuilder("child-agg1").userValueTypeHint(ValueType.STRING).field("key.sub1")) - .subAggregation(new TermsAggregationBuilder("child-agg2").userValueTypeHint(ValueType.STRING).field("key.sub2")) - ); - sourceBuilder.aggregation(new TopHitsAggregationBuilder("top_hits").storedField("_none_")); - sourceBuilder.aggregation(new SignificantTextAggregationBuilder("sig_text", "agg4").filterDuplicateText(true)); - sourceBuilder.aggregation(new MaxBucketPipelineAggregationBuilder("pipeline-agg4", "bucket4")); - sourceBuilder.aggregation(new DerivativePipelineAggregationBuilder("pipeline-agg3", "bucket3")); - sourceBuilder.aggregation(new AvgBucketPipelineAggregationBuilder("pipeline-agg5", "bucket5")); + SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createAggregationSearchSourceBuilder(); String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true); String expectedShowFieldsTrue = "aggregation:\n" @@ -237,11 +164,7 @@ public void testAggregationShape() { } public void testSortShape() { - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.sort("color", SortOrder.DESC); - sourceBuilder.sort("vendor", SortOrder.DESC); - sourceBuilder.sort("price", SortOrder.ASC); - sourceBuilder.sort("album", SortOrder.ASC); + SearchSourceBuilder sourceBuilder = SearchSourceBuilderUtils.createSortSearchSourceBuilder(); String shapeShowFieldsTrue = QueryShapeGenerator.buildShape(sourceBuilder, true); String expectedShowFieldsTrue = "sort:\n" + " asc [album]\n" + " asc [price]\n" + " desc [color]\n" + " desc [vendor]\n"; @@ -251,4 +174,28 @@ public void testSortShape() { String expectedShowFieldsFalse = "sort:\n" + " asc\n" + " asc\n" + " desc\n" + " desc\n"; assertEquals(expectedShowFieldsFalse, shapeShowFieldsFalse); } + + public void testHashCode() { + // Create test source builders: the default one (query, aggregations, sort) and a 
query-only one + SearchSourceBuilder defaultSourceBuilder = 
SearchSourceBuilderUtils.createDefaultSearchSourceBuilder(); + SearchSourceBuilder querySourceBuilder = SearchSourceBuilderUtils.createQuerySearchSourceBuilder(); + + // showFields = true: each hash is repeatable for the same source, and the two sources hash differently + int defaultHashTrue = QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, true); + int queryHashTrue = QueryShapeGenerator.getShapeHashCode(querySourceBuilder, true); + assertEquals(defaultHashTrue, QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, true)); + assertEquals(queryHashTrue, QueryShapeGenerator.getShapeHashCode(querySourceBuilder, true)); + assertNotEquals(defaultHashTrue, queryHashTrue); + + // showFields = false: the same repeatability and distinctness checks with showFields off + int defaultHashFalse = QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, false); + int queryHashFalse = QueryShapeGenerator.getShapeHashCode(querySourceBuilder, false); + assertEquals(defaultHashFalse, QueryShapeGenerator.getShapeHashCode(defaultSourceBuilder, false)); + assertEquals(queryHashFalse, QueryShapeGenerator.getShapeHashCode(querySourceBuilder, false)); + assertNotEquals(defaultHashFalse, queryHashFalse); + + // Same source with showFields on vs off must produce different hashes + assertNotEquals(defaultHashTrue, defaultHashFalse); + assertNotEquals(queryHashTrue, queryHashFalse); + } }