diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index d5528f541fb274..3682023f88a5cd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -111,6 +111,7 @@ import io.trino.sql.planner.iterative.rule.PruneMergeSourceColumns; import io.trino.sql.planner.iterative.rule.PruneOffsetColumns; import io.trino.sql.planner.iterative.rule.PruneOrderByInAggregation; +import io.trino.sql.planner.iterative.rule.PruneOrderByInWindowAggregation; import io.trino.sql.planner.iterative.rule.PruneOutputSourceColumns; import io.trino.sql.planner.iterative.rule.PrunePattenRecognitionColumns; import io.trino.sql.planner.iterative.rule.PrunePatternRecognitionSourceColumns; @@ -460,6 +461,7 @@ public PlanOptimizers( new MergeLimitWithDistinct(), new PruneCountAggregationOverScalar(metadata), new PruneOrderByInAggregation(metadata), + new PruneOrderByInWindowAggregation(metadata), new RewriteSpatialPartitioningAggregation(plannerContext), new SimplifyCountOverConstant(plannerContext), new PreAggregateCaseAggregations(plannerContext), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInWindowAggregation.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInWindowAggregation.java new file mode 100644 index 00000000000000..0da210a5a46a59 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneOrderByInWindowAggregation.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableMap; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.metadata.Metadata; +import io.trino.spi.function.FunctionKind; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.WindowNode.Function; + +import java.util.Map; +import java.util.Optional; + +import static io.trino.sql.planner.plan.Patterns.window; +import static java.util.Objects.requireNonNull; + +public class PruneOrderByInWindowAggregation + implements Rule +{ + private static final Pattern PATTERN = window(); + private final Metadata metadata; + + public PruneOrderByInWindowAggregation(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(WindowNode node, Captures captures, Context context) + { + boolean anyRewritten = false; + ImmutableMap.Builder rewritten = ImmutableMap.builder(); + for (Map.Entry entry : node.getWindowFunctions().entrySet()) { + Function function = entry.getValue(); + // getAggregateFunctionImplementation can be expensive, so check it last. + if (function.getOrderingScheme().isPresent() && + function.getResolvedFunction().functionKind() == FunctionKind.AGGREGATE && + !metadata.getAggregationFunctionMetadata(context.getSession(), function.getResolvedFunction()).isOrderSensitive()) { + function = new Function( + function.getResolvedFunction(), + function.getArguments(), + Optional.empty(), // prune + function.getFrame(), + function.isIgnoreNulls()); + anyRewritten = true; + } + rewritten.put(entry.getKey(), function); + } + + if (!anyRewritten) { + return Result.empty(); + } + return Result.ofPlanNode(new WindowNode( + node.getId(), + node.getSource(), + node.getSpecification(), + rewritten.buildOrThrow(), + node.getHashSymbol(), + node.getPrePartitionedInputs(), + node.getPreSortedOrderPrefix())); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index bbe26f05e3d07b..5605a0eb0d8df1 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -1105,7 +1105,12 @@ public static ExpectedValueProvider aggregationFunction(Str public static ExpectedValueProvider windowFunction(String name, List args, WindowNode.Frame frame) { - return new WindowFunctionProvider(name, frame, toSymbolAliases(args)); + return windowFunction(name, args, frame, ImmutableList.of()); + } + + public static ExpectedValueProvider windowFunction(String name, List args, WindowNode.Frame frame, List orderBy) + { + return new WindowFunctionProvider(name, frame, toSymbolAliases(args), orderBy); } public static List toSymbolReferences(List aliases, SymbolAliases symbolAliases) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunction.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunction.java index 19519c719a9a33..f3327b07fcb635 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunction.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunction.java @@ -15,21 +15,25 @@ import com.google.common.collect.ImmutableList; import io.trino.sql.ir.Expression; +import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.plan.WindowNode; import java.util.List; +import java.util.Optional; import static java.util.Objects.requireNonNull; public record WindowFunction( String name, WindowNode.Frame frame, - List arguments) + List arguments, + Optional orderBy) { - public WindowFunction(String name, WindowNode.Frame frame, List arguments) + public WindowFunction(String name, WindowNode.Frame frame, List arguments, Optional orderBy) { this.name = requireNonNull(name, "name is null"); this.frame = requireNonNull(frame, "frame is null"); this.arguments = ImmutableList.copyOf(arguments); + this.orderBy = requireNonNull(orderBy, "orderBy is null"); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionMatcher.java index bcfad93e4cc9e9..de5f1d194f54f2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionMatcher.java @@ -22,6 +22,7 @@ import io.trino.sql.planner.plan.WindowNode.Function; import java.util.Map; +import java.util.Objects; import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; @@ -70,6 +71,7 @@ private boolean windowFunctionMatches(Function windowFunction, WindowFunction ex { return expectedCall.name().equals(windowFunction.getResolvedFunction().signature().getName().getFunctionName()) && WindowFrameMatcher.matches(expectedCall.frame(), windowFunction.getFrame(), aliases) && + Objects.equals(expectedCall.orderBy(), windowFunction.getOrderingScheme()) && expectedCall.arguments().equals(windowFunction.getArguments()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionProvider.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionProvider.java index 63710d4531e476..5ac2dd2f485610 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionProvider.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/WindowFunctionProvider.java @@ -14,9 +14,16 @@ package io.trino.sql.planner.assertions; import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.connector.SortOrder; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.WindowNode; import java.util.List; +import java.util.Optional; import static io.trino.sql.planner.assertions.PlanMatchPattern.toSymbolReferences; import static java.util.Objects.requireNonNull; @@ -27,26 +34,43 @@ final class WindowFunctionProvider private final String name; private final WindowNode.Frame frame; private final List args; + private final List orderBy; - public WindowFunctionProvider(String name, WindowNode.Frame frame, List args) + public WindowFunctionProvider(String name, WindowNode.Frame frame, List args, List orderBy) { this.name = requireNonNull(name, "name is null"); this.frame = requireNonNull(frame, "frame is null"); this.args = requireNonNull(args, "args is null"); + this.orderBy = ImmutableList.copyOf(orderBy); } @Override public String toString() { - return "%s(%s) %s".formatted( + return "%s(%s%s) %s".formatted( name, Joiner.on(", ").join(args), + orderBy.isEmpty() ? "" : " ORDER BY " + Joiner.on(", ").join(orderBy), frame); } @Override public WindowFunction getExpectedValue(SymbolAliases aliases) { - return new WindowFunction(name, frame, toSymbolReferences(args, aliases)); + Optional orderByClause = Optional.empty(); + if (!orderBy.isEmpty()) { + ImmutableList.Builder fields = ImmutableList.builder(); + ImmutableMap.Builder orders = ImmutableMap.builder(); + + for (PlanMatchPattern.Ordering ordering : this.orderBy) { + Reference reference = aliases.get(ordering.getField()); + Symbol symbol = new Symbol(reference.type(), reference.name()); + fields.add(symbol); + orders.put(symbol, ordering.getSortOrder()); + } + orderByClause = Optional.of(new OrderingScheme(fields.build(), orders.buildOrThrow())); + } + + return new WindowFunction(name, frame, toSymbolReferences(args, aliases), orderByClause); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInWindowAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInWindowAggregation.java new file mode 100644 index 00000000000000..f5f756a169fdf9 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInWindowAggregation.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.metadata.Metadata; +import io.trino.metadata.ResolvedFunction; +import io.trino.spi.connector.SortOrder; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.WindowNode; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; + +import static io.trino.metadata.TestMetadataManager.createTestMetadataManager; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.trino.sql.planner.assertions.PlanMatchPattern.sort; +import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.assertions.PlanMatchPattern.window; +import static io.trino.sql.planner.assertions.PlanMatchPattern.windowFunction; +import static io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME; +import static io.trino.sql.tree.SortItem.NullOrdering.LAST; +import static io.trino.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestPruneOrderByInWindowAggregation + extends BaseRuleTest +{ + private static final Metadata METADATA = createTestMetadataManager(); + + @Test + public void testBasics() + { + tester().assertThat(new PruneOrderByInWindowAggregation(METADATA)) + .on(this::buildWindowNode) + .matches( + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(specification( + ImmutableList.of("key"), + ImmutableList.of(), + ImmutableMap.of())) + .addFunction( + "avg", + windowFunction("avg", ImmutableList.of("input"), DEFAULT_FRAME)) + .addFunction( + "array_agg", + windowFunction("array_agg", ImmutableList.of("input"), DEFAULT_FRAME, ImmutableList.of(sort("input", ASCENDING, LAST)))), + values("input", "key", "keyHash", "mask"))); + } + + private WindowNode buildWindowNode(PlanBuilder planBuilder) + { + Symbol avg = planBuilder.symbol("avg"); + Symbol arrayAgg = planBuilder.symbol("araray_agg"); + Symbol input = planBuilder.symbol("input"); + Symbol key = planBuilder.symbol("key"); + Symbol keyHash = planBuilder.symbol("keyHash"); + Symbol mask = planBuilder.symbol("mask"); + List sourceSymbols = ImmutableList.of(input, key, keyHash, mask); + + ResolvedFunction avgFunction = createTestMetadataManager().resolveBuiltinFunction("avg", fromTypes(BIGINT)); + ResolvedFunction arrayAggFunction = createTestMetadataManager().resolveBuiltinFunction("array_agg", fromTypes(BIGINT)); + + return planBuilder.window( + new DataOrganizationSpecification(ImmutableList.of(planBuilder.symbol("key", BIGINT)), Optional.empty()), + ImmutableMap.of( + avg, new WindowNode.Function(avgFunction, + ImmutableList.of(new Reference(BIGINT, "input")), + Optional.of(new OrderingScheme( + ImmutableList.of(new Symbol(BIGINT, "input")), + ImmutableMap.of(new Symbol(BIGINT, "input"), SortOrder.ASC_NULLS_LAST))), + DEFAULT_FRAME, + false), + arrayAgg, new WindowNode.Function(arrayAggFunction, + ImmutableList.of(new Reference(BIGINT, "input")), + Optional.of(new OrderingScheme( + ImmutableList.of(new Symbol(BIGINT, "input")), + ImmutableMap.of(new Symbol(BIGINT, "input"), SortOrder.ASC_NULLS_LAST))), + DEFAULT_FRAME, + false)), + planBuilder.values(sourceSymbols, ImmutableList.of())); + } +}