diff --git a/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedWindowAccumulator.java b/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedWindowAccumulator.java new file mode 100644 index 00000000000000..53e65e12b4405b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/aggregation/OrderedWindowAccumulator.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.aggregation; + +import com.google.common.collect.ImmutableList; +import io.trino.operator.PagesIndex; +import io.trino.operator.window.PagesWindowIndex; +import io.trino.spi.PageBuilder; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.block.ValueBlock; +import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.WindowAccumulator; +import io.trino.spi.function.WindowIndex; +import io.trino.spi.type.Type; + +import java.util.List; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkState; +import static java.util.Objects.requireNonNull; + +public class OrderedWindowAccumulator + implements WindowAccumulator +{ + PagesIndex.Factory pagesIndexFactory; + private WindowAccumulator delegate; + private final WindowAccumulator initialDelegate; + private final List argumentTypes; + private final List argumentChannels; + private final List sortChannels; + private final List sortOrders; + + private PageBuilder pageBuilder; + private final PagesIndex pagesIndex; + private boolean pagesIndexSorted; + + public OrderedWindowAccumulator( + PagesIndex.Factory pagesIndexFactory, + WindowAccumulator delegate, + List argumentTypes, + List argumentChannels, + List sortOrders) + { + this(pagesIndexFactory, delegate, argumentTypes, argumentChannels, sortOrders, pagesIndexFactory.newPagesIndex(argumentTypes, 10_000)); + } + + private OrderedWindowAccumulator( + PagesIndex.Factory pagesIndexFactory, + WindowAccumulator delegate, + List argumentTypes, + List argumentChannels, + List sortOrders, + PagesIndex pagesIndex) + { + this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"); + this.delegate = requireNonNull(delegate, "delegate is null"); + this.initialDelegate = delegate.copy(); + this.argumentTypes = requireNonNull(argumentTypes, "argumentTypes is null"); + this.argumentChannels = ImmutableList.copyOf(argumentChannels); + this.sortOrders = ImmutableList.copyOf(sortOrders); + this.pagesIndex = requireNonNull(pagesIndex, "pagesIndex is null"); + this.sortChannels = IntStream.range(argumentTypes.size() - sortOrders.size(), argumentChannels.size()) + .boxed() + .toList(); + resetPageBuilder(); + } + + private void resetPageBuilder() + { + pageBuilder = new PageBuilder(argumentTypes); + } + + @Override + public long getEstimatedSize() + { + return delegate.getEstimatedSize() + initialDelegate.getEstimatedSize() + pagesIndex.getEstimatedSize().toBytes() + pageBuilder.getRetainedSizeInBytes(); + } + + @Override + public WindowAccumulator copy() + { + PagesIndex pagesIndexCopy = pagesIndexFactory.newPagesIndex(argumentTypes, pagesIndex.getPositionCount()); + pagesIndex.getPages().forEachRemaining(pagesIndexCopy::addPage); + return new OrderedWindowAccumulator(pagesIndexFactory, delegate.copy(), argumentTypes, argumentChannels, sortOrders, pagesIndexCopy); + } + + @Override + public void addInput(WindowIndex index, int startPosition, int endPosition) + { + if (pagesIndexSorted) { + pagesIndexSorted = false; + // operate on delegate as of start + // nicer would be to add reset() method to WindowAccumulator but it requires reset method in each AccumulatorState class + delegate = initialDelegate.copy(); + } + // index is remapped so just go from 0 to argumentChannels.size() + for (int position = startPosition; position <= endPosition; position++) { + if (pageBuilder.isFull()) { + pagesIndex.addPage(pageBuilder.build()); + resetPageBuilder(); + } + for (int channel = 0; channel < argumentChannels.size(); channel++) { + ValueBlock value = index.getSingleValueBlock(channel, position).getSingleValueBlock(0); + pageBuilder.getBlockBuilder(channel).append(value, 0); + } + pageBuilder.declarePosition(); + } + } + + @Override + public void output(BlockBuilder blockBuilder) + { + if (!pagesIndexSorted) { + if (!pageBuilder.isEmpty()) { + pagesIndex.addPage(pageBuilder.build()); + resetPageBuilder(); + } + int positionCount = pagesIndex.getPositionCount(); + if (positionCount == 0) { + return; + } + pagesIndex.sort(sortChannels, sortOrders); + WindowIndex sortedWindowIndex = new PagesWindowIndex(pagesIndex, 0, positionCount); + delegate.addInput(sortedWindowIndex, 0, positionCount - 1); + pagesIndexSorted = true; + } + checkState(pageBuilder.isEmpty()); + + delegate.output(blockBuilder); + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java b/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java index 28d2f790d7af37..ef1da3562b4baf 100644 --- a/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java +++ b/core/trino-main/src/main/java/io/trino/operator/window/AggregateWindowFunction.java @@ -24,7 +24,7 @@ import static java.lang.Math.min; import static java.util.Objects.requireNonNull; -class AggregateWindowFunction +public class AggregateWindowFunction implements WindowFunction { private final Supplier accumulatorFactory; @@ -63,7 +63,6 @@ else if ((frameStart == currentStart) && (frameEnd >= currentEnd)) { else { buildNewFrame(frameStart, frameEnd); } - accumulator.output(output); } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionTreeUtils.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionTreeUtils.java index b7a115f80f50c3..cc7c786cc88eec 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionTreeUtils.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionTreeUtils.java @@ -85,9 +85,10 @@ public static List extractExpressions( private static boolean isAggregation(FunctionCall functionCall, Session session, FunctionResolver functionResolver, AccessControl accessControl) { - return ((functionResolver.isAggregationFunction(session, functionCall.getName(), accessControl) || functionCall.getFilter().isPresent()) - && functionCall.getWindow().isEmpty()) - || functionCall.getOrderBy().isPresent(); + return (functionResolver.isAggregationFunction(session, functionCall.getName(), accessControl) + || functionCall.getFilter().isPresent() + || functionCall.getOrderBy().isPresent()) + && functionCall.getWindow().isEmpty(); } private static boolean isWindow(FunctionCall functionCall, Session session, FunctionResolver functionResolver, AccessControl accessControl) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 3e15fd403f2801..fc6da8458ab3b1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -4334,10 +4334,6 @@ private List analyzeWindowFunctions(QuerySpecification node, List< throw semanticException(NOT_SUPPORTED, node, "FILTER is not yet supported for window functions"); } - if (windowFunction.getOrderBy().isPresent()) { - throw semanticException(NOT_SUPPORTED, windowFunction, "Window function with ORDER BY is not supported"); - } - List nestedWindowExpressions = extractWindowExpressions(windowFunction.getArguments()); if (!nestedWindowExpressions.isEmpty()) { throw semanticException(NESTED_WINDOW, nestedWindowExpressions.getFirst(), "Cannot nest window functions or row pattern measures inside window function arguments"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 10ed0de5595661..1e8612c47560f6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -103,6 +103,7 @@ import io.trino.operator.aggregation.AggregatorFactory; import io.trino.operator.aggregation.DistinctAccumulatorFactory; import io.trino.operator.aggregation.OrderedAccumulatorFactory; +import io.trino.operator.aggregation.OrderedWindowAccumulator; import io.trino.operator.aggregation.partial.PartialAggregationController; import io.trino.operator.exchange.LocalExchange; import io.trino.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory; @@ -134,6 +135,7 @@ import io.trino.operator.project.PageProcessor; import io.trino.operator.project.PageProjection; import io.trino.operator.unnest.UnnestOperator; +import io.trino.operator.window.AggregateWindowFunction; import io.trino.operator.window.AggregationWindowFunctionSupplier; import io.trino.operator.window.FrameInfo; import io.trino.operator.window.PartitionerSupplier; @@ -171,6 +173,7 @@ import io.trino.spi.function.CatalogSchemaFunctionName; import io.trino.spi.function.FunctionId; import io.trino.spi.function.FunctionKind; +import io.trino.spi.function.WindowFunction; import io.trino.spi.function.WindowFunctionSupplier; import io.trino.spi.function.table.TableFunctionProcessorProvider; import io.trino.spi.predicate.Domain; @@ -1114,7 +1117,6 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext Optional sortKeyChannelForEndComparison = Optional.empty(); Optional sortKeyChannel = Optional.empty(); Optional ordering = Optional.empty(); - Frame frame = entry.getValue().getFrame(); if (frame.getStartValue().isPresent()) { frameStartChannel = Optional.of(source.getLayout().get(frame.getStartValue().get())); @@ -1145,15 +1147,15 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext WindowNode.Function function = entry.getValue(); ResolvedFunction resolvedFunction = function.getResolvedFunction(); - ImmutableList.Builder arguments = ImmutableList.builder(); + ArrayList argumentChannels = new ArrayList<>(); for (Expression argument : function.getArguments()) { if (!(argument instanceof Lambda)) { Symbol argumentSymbol = Symbol.from(argument); - arguments.add(source.getLayout().get(argumentSymbol)); + argumentChannels.add(source.getLayout().get(argumentSymbol)); } } Symbol symbol = entry.getKey(); - WindowFunctionSupplier windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction); + Type type = resolvedFunction.signature().getReturnType(); List lambdas = function.getArguments().stream() @@ -1165,8 +1167,33 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext .map(FunctionType.class::cast) .collect(toImmutableList()); + WindowFunctionSupplier windowFunctionSupplier; + if (function.getOrderingScheme().isPresent()) { + OrderingScheme orderingScheme = function.getOrderingScheme().orElseThrow(); + List sortKeys = orderingScheme.orderBy(); + List sortOrders = sortKeys.stream() + .map(orderingScheme::ordering) + .collect(toImmutableList()); + sortKeys.forEach(orderingArgumentSymbol -> argumentChannels.add(source.getLayout().get(orderingArgumentSymbol))); + + List argumentTypes = argumentChannels.stream() + .map(channel -> source.getTypes().get(channel)) + .collect(toImmutableList()); + + windowFunctionSupplier = getOrderedWindowFunctionImplementation( + resolvedFunction, + argumentTypes, + argumentChannels, + sortOrders); + } + else { + windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction); + } + List> lambdaProviders = makeLambdaProviders(lambdas, windowFunctionSupplier.getLambdaInterfaces(), functionTypes); - windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), lambdaProviders, arguments.build())); + WindowFunctionDefinition windowFunction = window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), lambdaProviders, ImmutableList.copyOf(argumentChannels)); + + windowFunctionsBuilder.add(windowFunction); windowFunctionOutputSymbolsBuilder.add(symbol); } @@ -1210,17 +1237,47 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext private WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction) { if (resolvedFunction.functionKind() == FunctionKind.AGGREGATE) { - return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.functionId(), resolvedFunction.signature()), () -> { - AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction); - return new AggregationWindowFunctionSupplier( - resolvedFunction.signature(), - aggregationImplementation, - resolvedFunction.functionNullability()); - }); + return getAggregationWindowFunctionSupplier(resolvedFunction); } return plannerContext.getFunctionManager().getWindowFunctionSupplier(resolvedFunction); } + private AggregationWindowFunctionSupplier getAggregationWindowFunctionSupplier(ResolvedFunction resolvedFunction) + { + checkArgument(resolvedFunction.functionKind() == FunctionKind.AGGREGATE); + return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.functionId(), resolvedFunction.signature()), () -> { + AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction); + return new AggregationWindowFunctionSupplier( + resolvedFunction.signature(), + aggregationImplementation, + resolvedFunction.functionNullability()); + }); + } + + private WindowFunctionSupplier getOrderedWindowFunctionImplementation( + ResolvedFunction resolvedFunction, + List argumentTypes, + List argumentChannels, + List sortOrders) + { + AggregationWindowFunctionSupplier aggregationWindowFunctionSupplier = getAggregationWindowFunctionSupplier(resolvedFunction); + return new WindowFunctionSupplier() { + @Override + public WindowFunction createWindowFunction(boolean ignoreNulls, List> lambdaProviders) + { + AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction); + boolean hasRemoveInput = aggregationImplementation.getWindowAccumulator().isPresent(); + return new AggregateWindowFunction(() -> new OrderedWindowAccumulator(pagesIndexFactory, aggregationWindowFunctionSupplier.createWindowAccumulator(lambdaProviders), argumentTypes, argumentChannels, sortOrders), hasRemoveInput); + } + + @Override + public List> getLambdaInterfaces() + { + return aggregationWindowFunctionSupplier.getLambdaInterfaces(); + } + }; + } + @Override public PhysicalOperation visitPatternRecognition(PatternRecognitionNode node, LocalExecutionPlanContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java index 8278515b22a771..759a2237108c6b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/QueryPlanner.java @@ -314,7 +314,7 @@ public RelationPlan planExpand(Query query) NodeAndMappings checkConvergenceStep = copy(recursionStep, mappings); Symbol countSymbol = symbolAllocator.newSymbol("count", BIGINT); ResolvedFunction function = plannerContext.getMetadata().resolveBuiltinFunction("count", ImmutableList.of()); - WindowNode.Function countFunction = new WindowNode.Function(function, ImmutableList.of(), DEFAULT_FRAME, false); + WindowNode.Function countFunction = new WindowNode.Function(function, ImmutableList.of(), Optional.empty(), DEFAULT_FRAME, false); WindowNode windowNode = new WindowNode( idAllocator.getNextId(), @@ -1457,6 +1457,9 @@ private PlanBuilder planWindowFunctions(Node node, PlanBuilder subPlan, List !(argument instanceof LambdaExpression)) // lambda expression is generated at execution time .collect(Collectors.toList())); + inputsBuilder.addAll(getSortItemsFromOrderBy(windowFunction.getOrderBy()).stream() + .map(SortItem::getSortKey) + .iterator()); } List inputs = inputsBuilder.build(); @@ -1778,6 +1781,7 @@ private PlanBuilder planWindow( return coercions.get(argument).toSymbolReference(); }) .collect(toImmutableList()), + windowFunction.getOrderBy().map(orderBy -> translateOrderingScheme(orderBy.getSortItems(), coercions::get)), frame, nullTreatment == NullTreatment.IGNORE); @@ -1859,6 +1863,7 @@ private PlanBuilder planPatternRecognition( return coercions.get(argument).toSymbolReference(); }) .collect(toImmutableList()), + Optional.empty(), baseFrame, nullTreatment == NullTreatment.IGNORE); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java index ea701961326357..de62cbe6d6a857 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DecorrelateUnnest.java @@ -464,6 +464,7 @@ public RewriteResult visitTopN(TopNNode node, Void context) WindowNode.Function rowNumberFunction = new WindowNode.Function( metadata.resolveBuiltinFunction("row_number", ImmutableList.of()), ImmutableList.of(), + Optional.empty(), DEFAULT_FRAME, false); WindowNode windowNode = new WindowNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java index 3942f900480f2c..f7afe14a5958fb 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementLimitWithTies.java @@ -123,6 +123,7 @@ public static PlanNode rewriteLimitWithTiesWithPartitioning(LimitNode limitNode, WindowNode.Function rankFunction = new WindowNode.Function( metadata.resolveBuiltinFunction("rank", ImmutableList.of()), ImmutableList.of(), + Optional.empty(), DEFAULT_FRAME, false); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java index 673926e0d3f60c..d80b48d7a8e980 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java @@ -317,8 +317,8 @@ private static NodeWithSymbols planWindowFunctionsForSource( source, specification, ImmutableMap.of( - rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), FULL_FRAME, false), - partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), FULL_FRAME, false)), + rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), Optional.empty(), FULL_FRAME, false), + partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), Optional.empty(), FULL_FRAME, false)), Optional.empty(), ImmutableSet.of(), 0); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneWindowColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneWindowColumns.java index 2091f1dcb859ac..b1a95f72410ace 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneWindowColumns.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneWindowColumns.java @@ -59,6 +59,7 @@ protected Optional pushDownProjectOff(Context context, WindowNode wind windowNode.getHashSymbol().ifPresent(referencedInputs::add); for (WindowNode.Function windowFunction : referencedFunctions.values()) { + windowFunction.getOrderingScheme().ifPresent(orderingScheme -> referencedInputs.addAll(orderingScheme.orderBy())); referencedInputs.addAll(SymbolsExtractor.extractUnique(windowFunction)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java index eabcdeae5978b6..0144f80148ef69 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java @@ -138,6 +138,7 @@ public Result apply(ProjectNode projectNode, Captures captures, Context context) oldFunction.getArguments().stream() .map(expression -> replaceExpression(expression, mappings)) .collect(toImmutableList()), + oldFunction.getOrderingScheme(), // TODO ????????? oldFunction.getFrame(), oldFunction.isIgnoreNulls()); })), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java index 892065e06d668a..65f81f028c25c5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SetOperationNodeTranslator.java @@ -194,6 +194,7 @@ private WindowNode appendCounts(UnionNode sourceNode, List originalColum functions.put(output, new WindowNode.Function( countFunction, ImmutableList.of(markers.get(i).toSymbolReference()), + Optional.empty(), defaultFrame, false)); } @@ -201,6 +202,7 @@ private WindowNode appendCounts(UnionNode sourceNode, List originalColum functions.put(rowNumberSymbol, new WindowNode.Function( rowNumberFunction, ImmutableList.of(), + Optional.empty(), defaultFrame, false)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 4257e0df3210d2..a0a87edea09cc2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -69,6 +69,7 @@ import java.util.function.Function; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -248,8 +249,9 @@ public WindowNode map(WindowNode node, PlanNode source) .map(this::map) .collect(toImmutableList()); WindowNode.Frame newFrame = map(function.getFrame()); + Optional newOrderingScheme = function.getOrderingScheme().map(this::map); - newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls())); + newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newOrderingScheme, newFrame, function.isIgnoreNulls())); }); SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix()); @@ -307,8 +309,9 @@ public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) .map(this::map) .collect(toImmutableList()); WindowNode.Frame newFrame = map(function.getFrame()); + verify(function.getOrderingScheme().isEmpty()); - newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls())); + newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, Optional.empty(), newFrame, function.isIgnoreNulls())); }); ImmutableMap.Builder newMeasures = ImmutableMap.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java index eea37b3f02f3d7..cfd797cd57ac9a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/WindowNode.java @@ -290,6 +290,7 @@ public static final class Function { private final ResolvedFunction resolvedFunction; private final List arguments; + private final Optional orderingScheme; private final Frame frame; private final boolean ignoreNulls; @@ -297,11 +298,13 @@ public static final class Function public Function( @JsonProperty("resolvedFunction") ResolvedFunction resolvedFunction, @JsonProperty("arguments") List arguments, + @JsonProperty("orderingScheme") Optional orderingScheme, @JsonProperty("frame") Frame frame, @JsonProperty("ignoreNulls") boolean ignoreNulls) { this.resolvedFunction = requireNonNull(resolvedFunction, "resolvedFunction is null"); this.arguments = requireNonNull(arguments, "arguments is null"); + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); this.frame = requireNonNull(frame, "frame is null"); this.ignoreNulls = ignoreNulls; } @@ -318,6 +321,12 @@ public List getArguments() return arguments; } + @JsonProperty + public Optional getOrderingScheme() + { + return orderingScheme; + } + @JsonProperty public Frame getFrame() { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 602acd282c1e30..1b773846e3fa7e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -912,11 +912,16 @@ public Void visitWindow(WindowNode node, Context context) WindowNode.Function function = entry.getValue(); String frameInfo = formatFrame(function.getFrame()); + String orderingInfo = function.getOrderingScheme().map(orderingScheme -> " " + orderingScheme.orderBy().stream() + .map(input -> anonymizer.anonymize(input) + " " + orderingScheme.ordering(input)) + .collect(joining(", "))).orElse(""); + nodeOutput.appendDetails( - "%s := %s(%s) %s", + "%s := %s(%s%s) %s", anonymizer.anonymize(entry.getKey()), formatFunctionName(function.getResolvedFunction()), Joiner.on(", ").join(anonymizeExpressions(function.getArguments())), + orderingInfo, frameInfo); } return processChildren(node, new Context(context.isInitialPlan())); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java index abf40c00f72dc9..71418ac88add72 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTypeValidator.java @@ -143,7 +143,7 @@ public void testValidWindow() Optional.empty(), Optional.empty()); - WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), frame, false); + WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), Optional.empty(), frame, false); DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); @@ -238,7 +238,7 @@ public void testInvalidWindowFunctionCall() Optional.empty(), Optional.empty()); - WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnA.toSymbolReference()), frame, false); + WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnA.toSymbolReference()), Optional.empty(), frame, false); DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); @@ -271,7 +271,7 @@ public void testInvalidWindowFunctionSignature() Optional.empty(), Optional.empty()); - WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), frame, false); + WindowNode.Function function = new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), Optional.empty(), frame, false); DataOrganizationSpecification specification = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java index 1e36b7c76f8788..2f29f960f02992 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergeAdjacentWindows.java @@ -224,6 +224,7 @@ private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolv Arrays.stream(symbols) .map(Symbol::toSymbolReference) .collect(Collectors.toList()), + Optional.empty(), DEFAULT_FRAME, false); } @@ -235,6 +236,7 @@ private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolv Arrays.stream(symbols) .map(name -> new Reference(DOUBLE, name)) .collect(Collectors.toList()), + Optional.empty(), DEFAULT_FRAME, false); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java index e0e3fcc46facc8..98a0cc8f4a17b9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMergePatternRecognitionNodes.java @@ -166,7 +166,7 @@ public void testParentDependsOnSourceCreatedOutputs() .pattern(new IrLabel("X")) .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder - .addWindowFunction(p.symbol("function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) @@ -177,13 +177,13 @@ public void testParentDependsOnSourceCreatedOutputs() // parent node's window function depends on child node's window function output tester().assertThat(new MergePatternRecognitionNodesWithoutProject()) .on(p -> p.patternRecognition(parentBuilder -> parentBuilder - .addWindowFunction(p.symbol("dependent"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("function").toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("dependent"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("function").toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) .addVariableDefinition(new IrLabel("X"), TRUE) .source(p.patternRecognition(childBuilder -> childBuilder - .addWindowFunction(p.symbol("function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) @@ -194,7 +194,7 @@ public void testParentDependsOnSourceCreatedOutputs() // parent node's window function depends on child node's measure output tester().assertThat(new MergePatternRecognitionNodesWithoutProject()) .on(p -> p.patternRecognition(parentBuilder -> parentBuilder - .addWindowFunction(p.symbol("dependent"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("measure").toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("dependent"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("measure").toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .pattern(new IrLabel("X")) @@ -306,7 +306,7 @@ public void testMergeWithoutProject() ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), true, true, 0, 0), new Symbol(BIGINT, "b")))) - .addWindowFunction(p.symbol("parent_function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("parent_function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("a").toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .skipTo(LAST, ImmutableSet.of(new IrLabel("X"))) @@ -324,7 +324,7 @@ public void testMergeWithoutProject() ImmutableMap.of("pointer", new ScalarValuePointer( new LogicalIndexPointer(ImmutableSet.of(new IrLabel("X")), false, true, 0, 0), new Symbol(BIGINT, "a")))) - .addWindowFunction(p.symbol("child_function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b").toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("child_function"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b").toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .rowsPerMatch(WINDOW) .frame(new WindowNode.Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty())) .skipTo(LAST, ImmutableSet.of(new IrLabel("X"))) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java index dd3c54bdca54db..6fcdcf768f14f3 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPrunePattenRecognitionColumns.java @@ -103,7 +103,7 @@ public void testRemovePatternRecognitionNode() .on(p -> p.project( Assignments.identity(p.symbol("b")), p.patternRecognition(builder -> builder - .addWindowFunction(p.symbol("rank"), new WindowNode.Function(rank, ImmutableList.of(), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("rank"), new WindowNode.Function(rank, ImmutableList.of(), Optional.empty(), DEFAULT_FRAME, false)) .addMeasure( p.symbol("measure"), new Reference(BIGINT, "pointer"), @@ -132,7 +132,7 @@ public void testPruneUnreferencedWindowFunctionAndSources() .on(p -> p.project( Assignments.identity(p.symbol("measure", BIGINT)), p.patternRecognition(builder -> builder - .addWindowFunction(p.symbol("lag", BIGINT), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b", BIGINT).toSymbolReference()), DEFAULT_FRAME, false)) + .addWindowFunction(p.symbol("lag", BIGINT), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b", BIGINT).toSymbolReference()), Optional.empty(), DEFAULT_FRAME, false)) .addMeasure( p.symbol("measure", BIGINT), new Reference(BIGINT, "pointer"), @@ -191,7 +191,7 @@ public void testPruneUnreferencedMeasureAndSources() .on(p -> p.project( Assignments.identity(p.symbol("lag")), p.patternRecognition(builder -> builder - .addWindowFunction(p.symbol("lag"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b").toSymbolReference()), frame, false)) + .addWindowFunction(p.symbol("lag"), new WindowNode.Function(lag, ImmutableList.of(p.symbol("b").toSymbolReference()), Optional.empty(), frame, false)) .addMeasure( p.symbol("measure"), new Reference(BIGINT, "pointer"), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java index 9f9b4fe3e78ab1..c9321850741e47 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneWindowColumns.java @@ -206,6 +206,7 @@ private static PlanNode buildProjectedWindow( new WindowNode.Function( MIN_FUNCTION, ImmutableList.of(input1.toSymbolReference()), + Optional.empty(), new WindowNode.Frame( RANGE, UNBOUNDED_PRECEDING, @@ -219,6 +220,7 @@ private static PlanNode buildProjectedWindow( new WindowNode.Function( MIN_FUNCTION, ImmutableList.of(input2.toSymbolReference()), + Optional.empty(), new WindowNode.Frame( RANGE, UNBOUNDED_PRECEDING, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java index 3852963cb92ecb..657ddacbbc6fa5 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -626,6 +626,7 @@ public void testPushdownDereferenceThroughWindow() new WindowNode.Function( createTestMetadataManager().resolveBuiltinFunction("min", fromTypes(ROW_TYPE)), ImmutableList.of(p.symbol("msg3", ROW_TYPE).toSymbolReference()), + Optional.empty(), new WindowNode.Frame( RANGE, UNBOUNDED_PRECEDING, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java index a3321fd0637abc..588a97e65ad485 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushPredicateThroughProjectIntoWindow.java @@ -297,6 +297,7 @@ private Function rowNumberFunction() return new Function( tester().getMetadata().resolveBuiltinFunction("row_number", fromTypes()), ImmutableList.of(), + Optional.empty(), DEFAULT_FRAME, false); } @@ -306,6 +307,7 @@ private Function rankFunction() return new Function( tester().getMetadata().resolveBuiltinFunction("rank", fromTypes()), ImmutableList.of(), + Optional.empty(), DEFAULT_FRAME, false); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java index 004cd58deac4cd..de7644028a386d 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownFilterIntoWindow.java @@ -171,6 +171,7 @@ private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolv return new WindowNode.Function( resolvedFunction, ImmutableList.of(symbol.toSymbolReference()), + Optional.empty(), DEFAULT_FRAME, false); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java index 5d62679ff92781..7caff8f6350ea8 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushdownLimitIntoWindow.java @@ -201,6 +201,7 @@ private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolv return new WindowNode.Function( resolvedFunction, ImmutableList.of(symbol.toSymbolReference()), + Optional.empty(), DEFAULT_FRAME, false); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java index 8264bc626070e0..bf83a59909e588 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestReplaceWindowWithRowNumber.java @@ -115,6 +115,7 @@ private static WindowNode.Function newWindowNodeFunction(ResolvedFunction resolv return new WindowNode.Function( resolvedFunction, ImmutableList.of(), + Optional.empty(), DEFAULT_FRAME, false); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java index 74c01fec9c9f75..34a152f797e57f 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestSwapAdjacentWindowsBySpecifications.java @@ -61,7 +61,7 @@ public void doesNotFireOnPlanWithSingleWindowNode() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1"), - new WindowNode.Function(resolvedFunction, ImmutableList.of(), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(), Optional.empty(), DEFAULT_FRAME, false)), p.values(p.symbol("a")))) .doesNotFire(); } @@ -81,12 +81,12 @@ public void subsetComesFirst() ImmutableList.of(p.symbol("a")), Optional.empty()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "a")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "a")), Optional.empty(), DEFAULT_FRAME, false)), p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a"), p.symbol("b")), Optional.empty()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "b")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "b")), Optional.empty(), DEFAULT_FRAME, false)), p.values(p.symbol("a"), p.symbol("b"))))) .matches( window(windowMatcherBuilder -> windowMatcherBuilder @@ -107,12 +107,12 @@ public void dependentWindowsAreNotReordered() ImmutableList.of(p.symbol("a", BIGINT)), Optional.empty()), ImmutableMap.of(p.symbol("avg_1", DOUBLE), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(DOUBLE, "avg_2")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(DOUBLE, "avg_2")), Optional.empty(), DEFAULT_FRAME, false)), p.window(new DataOrganizationSpecification( ImmutableList.of(p.symbol("a", BIGINT), p.symbol("b", BIGINT)), Optional.empty()), ImmutableMap.of(p.symbol("avg_2", DOUBLE), - new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "a")), DEFAULT_FRAME, false)), + new WindowNode.Function(resolvedFunction, ImmutableList.of(new Reference(BIGINT, "a")), Optional.empty(), DEFAULT_FRAME, false)), p.values(p.symbol("a", BIGINT), p.symbol("b", BIGINT))))) .doesNotFire(); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java index 0d508dc282aa29..55b542e3335ffa 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestPatternRecognitionNodeSerialization.java @@ -208,6 +208,7 @@ public void testPatternRecognitionNodeRoundtrip() new Function( rankFunction, ImmutableList.of(), + Optional.empty(), new Frame(ROWS, CURRENT_ROW, Optional.empty(), Optional.empty(), UNBOUNDED_FOLLOWING, Optional.empty(), Optional.empty()), false)), ImmutableMap.of( diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java index ae9209f87a80ba..2136adad78aafb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/plan/TestWindowNode.java @@ -93,7 +93,7 @@ public void testSerializationRoundtrip() Optional.of(new OrderingScheme( ImmutableList.of(columnB), ImmutableMap.of(columnB, SortOrder.ASC_NULLS_FIRST)))); - Map functions = ImmutableMap.of(windowSymbol, new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), frame, false)); + Map functions = ImmutableMap.of(windowSymbol, new WindowNode.Function(resolvedFunction, ImmutableList.of(columnC.toSymbolReference()), Optional.empty(), frame, false)); Optional hashSymbol = Optional.of(columnB); Set prePartitionedInputs = ImmutableSet.of(columnA); WindowNode windowNode = new WindowNode( diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestListagg.java b/core/trino-main/src/test/java/io/trino/sql/query/TestListagg.java index 191478037e90fa..0c0857f75ea988 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestListagg.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestListagg.java @@ -16,17 +16,15 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; -import org.junit.jupiter.api.parallel.Execution; import static io.trino.spi.StandardErrorCode.EXCEEDED_FUNCTION_MEMORY_LIMIT; import static io.trino.spi.StandardErrorCode.SYNTAX_ERROR; import static io.trino.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; -import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @TestInstance(PER_CLASS) -@Execution(CONCURRENT) +//@Execution(CONCURRENT) public class TestListagg { private final QueryAssertions assertions = new QueryAssertions(); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestOrderedAggregation.java b/core/trino-main/src/test/java/io/trino/sql/query/TestOrderedAggregation.java index f76431cddad512..c04f8782a2e763 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestOrderedAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestOrderedAggregation.java @@ -103,10 +103,6 @@ public void testAggregationWithOrderBy() "SELECT x, y, array_agg(z ORDER BY z) FROM (VALUES (1, 2, 3), (1, 2, 1), (2, 1, 3), (2, 1, 4)) t(x, y, z) GROUP BY GROUPING SETS ((x), (x, y))")) .matches("VALUES (1, NULL, ARRAY[1, 3]), (2, NULL, ARRAY[3, 4]), (1, 2, ARRAY[1, 3]), (2, 1, ARRAY[3, 4])"); - assertThat(assertions.query( - "SELECT array_agg(z ORDER BY z) OVER (PARTITION BY x) FROM (VALUES (1, 2, 3), (1, 2, 1), (2, 1, 3), (2, 1, 4)) t(x, y, z) GROUP BY x, z")) - .failure().hasMessageMatching(".* Window function with ORDER BY is not supported"); - assertThat(assertions.query( "SELECT array_agg(DISTINCT x ORDER BY y) FROM (VALUES (1, 2), (3, 5), (4, 1)) t(x, y)")) .failure().hasMessageMatching(".* For aggregate function with DISTINCT, ORDER BY expressions must appear in arguments"); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestWindow.java b/core/trino-main/src/test/java/io/trino/sql/query/TestWindow.java index e42f14112a6c9a..54be6ab5fb481e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestWindow.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestWindow.java @@ -72,4 +72,198 @@ WINDOW w AS (ORDER BY k ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) """)) .matches("VALUES (BIGINT '1', BIGINT '1', 1, 1, BIGINT '1', BIGINT '1', 1, 1, BIGINT '1', BIGINT '1', 1, 1, BIGINT '1', BIGINT '1', 1, 1, BIGINT '1', BIGINT '1', 1, 1, BIGINT '1', BIGINT '1', 1, 1)"); } + + @Test + public void testWindowWithOrderBy() + { + // window and aggregate ordering on different columns + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY c) OVER w + FROM ( + VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 4, 7) , (1, 5, 5), + (2, 1, 1), (2, 2, 3), (2, 3, 2), (2, 4, 1) + ) AS t(a, b, c) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[1, 2]), + (1, ARRAY[1, 2, 3]), + (1, ARRAY[1, 2, 3, 4, 7]), + (1, ARRAY[1, 2, 3, 4, 7]), + (1, ARRAY[1, 2, 3, 4, 5, 7]), + (2, ARRAY[1]), + (2, ARRAY[1, 3]), + (2, ARRAY[1, 2, 3]), + (2, ARRAY[1, 1, 2, 3]) + """); + + // window and aggregate ordering on different columns (different ordering) + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY c DESC) OVER w + FROM ( + VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 4, 7), (1, 5, 5), + (2, 1, 1), (2, 2, 3), (2, 3, 2), (2, 4, 1) + ) AS t(a, b, c) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[2, 1]), + (1, ARRAY[3, 2, 1]), + (1, ARRAY[7, 4, 3, 2, 1]), + (1, ARRAY[7, 4, 3, 2, 1]), + (1, ARRAY[7, 5, 4, 3, 2, 1]), + (2, ARRAY[1]), + (2, ARRAY[3, 1]), + (2, ARRAY[3, 2, 1]), + (2, ARRAY[3, 2, 1, 1]) + """); + + // aggregate ordering on column not in output + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY d) OVER w + FROM ( + VALUES (1, 1, 1, 5), (1, 2, 2, 4), (1, 3, 3, 1), (1, 4, 4, 2), (1, 4, 7, 6) , (1, 5, 5, 3), + (2, 1, 1, 4), (2, 2, 3, 3), (2, 3, 2, 2), (2, 4, 1, 1) + ) AS t(a, b, c, d) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[2, 1]), + (1, ARRAY[3, 2, 1]), + (1, ARRAY[3, 4, 2, 1, 7]), + (1, ARRAY[3, 4, 2, 1, 7]), + (1, ARRAY[3, 4, 5, 2, 1, 7]), + (2, ARRAY[1]), + (2, ARRAY[3, 1]), + (2, ARRAY[2, 3, 1]), + (2, ARRAY[1, 2, 3, 1]) + """); + + // window and aggregate ordering on the same column + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY b) OVER w + FROM ( + VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4) , (1, 5, 5), + (2, 1, 1), (2, 2, 3), (2, 3, 2), (2, 4, 1) + ) AS t(a, b, c) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[1, 2]), + (1, ARRAY[1, 2, 3]), + (1, ARRAY[1, 2, 3, 4]), + (1, ARRAY[1, 2, 3, 4, 5]), + (2, ARRAY[1]), + (2, ARRAY[1, 3]), + (2, ARRAY[1, 3, 2]), + (2, ARRAY[1, 3, 2, 1]) + """); + + // window and aggregate ordering on same column (different sort order) + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY b DESC) OVER w + FROM ( + VALUES (1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5), + (2, 1, 1), (2, 2, 3), (2, 3, 2), (2, 4, 1) + ) AS t(a, b, c) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[2, 1]), + (1, ARRAY[3, 2, 1]), + (1, ARRAY[4, 3, 2, 1]), + (1, ARRAY[5, 4, 3, 2, 1]), + (2, ARRAY[1]), + (2, ARRAY[3, 1]), + (2, ARRAY[2, 3, 1]), + (2, ARRAY[1, 2, 3, 1]) + """); + + // aggregate ordering on two columns (tiebreaker ASC) + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY d, c) OVER w + FROM ( + VALUES (1, 1, 1, 5), (1, 2, 2, 4), (1, 3, 3, 4), (1, 4, 4, 5), (1, 4, 7, 1) , (1, 5, 5, 2) + ) AS t(a, b, c, d) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[2, 1]), + (1, ARRAY[2, 3, 1]), + (1, ARRAY[7, 2, 3, 1, 4]), + (1, ARRAY[7, 2, 3, 1, 4]), + (1, ARRAY[7, 5, 2, 3, 1, 4]) + """); + + // aggregate ordering on two columns (tiebreaker DESC) + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY d, c DESC) OVER w + FROM ( + VALUES (1, 1, 1, 5), (1, 2, 2, 4), (1, 3, 3, 4), (1, 4, 4, 5), (1, 4, 7, 1) , (1, 5, 5, 2) + ) AS t(a, b, c, d) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1]), + (1, ARRAY[2, 1]), + (1, ARRAY[3, 2, 1]), + (1, ARRAY[7, 3, 2, 4, 1]), + (1, ARRAY[7, 3, 2, 4, 1]), + (1, ARRAY[7, 5, 3, 2, 4, 1]) + """); + + // multiple aggregate functions + assertThat(assertions.query(""" + SELECT + a, + ARRAY_AGG(c ORDER BY c) OVER w, + ARRAY_AGG(c ORDER BY c DESC) OVER w, + ARRAY_AGG(c ORDER BY d) OVER w + FROM ( + VALUES (1, 1, 1, 5), (1, 2, 2, 4), (1, 3, 3, 1), (1, 4, 4, 2), (1, 4, 7, 6) , (1, 5, 5, 3), + (2, 1, 1, 4), (2, 2, 3, 3), (2, 3, 2, 2), (2, 4, 1, 1) + ) AS t(a, b, c, d) + WINDOW w AS (PARTITION BY a ORDER BY b) + """)) + .matches(""" + VALUES + (1, ARRAY[1], ARRAY[1], ARRAY[1]), + (1, ARRAY[1, 2], ARRAY[2, 1], ARRAY[2, 1]), + (1, ARRAY[1, 2, 3], ARRAY[3, 2, 1], ARRAY[3, 2, 1]), + (1, ARRAY[1, 2, 3, 4, 7], ARRAY[7, 4, 3, 2, 1], ARRAY[3, 4, 2, 1, 7]), + (1, ARRAY[1, 2, 3, 4, 7], ARRAY[7, 4, 3, 2, 1], ARRAY[3, 4, 2, 1, 7]), + (1, ARRAY[1, 2, 3, 4, 5, 7], ARRAY[7, 5, 4, 3, 2, 1], ARRAY[3, 4, 5, 2, 1, 7]), + (2, ARRAY[1], ARRAY[1], ARRAY[1]), + (2, ARRAY[1, 3], ARRAY[3, 1], ARRAY[3, 1]), + (2, ARRAY[1, 2, 3], ARRAY[3, 2, 1], ARRAY[2, 3, 1]), + (2, ARRAY[1, 1, 2, 3], ARRAY[3, 2, 1, 1], ARRAY[1, 2, 3, 1]) + """); + } }