Skip to content

Commit

Permalink
Support ORDER BY in window functions (WIP)
Browse files Browse the repository at this point in the history
  • Loading branch information
losipiuk committed Oct 26, 2024
1 parent 4529b94 commit e43e7f9
Show file tree
Hide file tree
Showing 30 changed files with 323 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.operator.PagesIndex;
import io.trino.operator.window.PagesWindowIndex;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.function.WindowAccumulator;
import io.trino.spi.function.WindowIndex;
import io.trino.spi.type.Type;

import java.util.List;
import java.util.stream.IntStream;

import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

/**
 * A {@link WindowAccumulator} that buffers the rows of the current frame, sorts them, and only then
 * feeds them to a delegate accumulator, so the delegate observes its input in the requested order.
 * <p>
 * Rows passed to {@link #addInput} are copied into a {@link PagesIndex}. The sort and the hand-off to the
 * delegate happen lazily in {@link #output}. When more rows arrive after an output was produced, the
 * delegate is replaced by a fresh copy of its initial state and all buffered rows are replayed in sorted
 * order on the next {@link #output} call.
 * <p>
 * The sort keys are expected to be the trailing {@code sortOrders.size()} channels of the (remapped)
 * {@link WindowIndex} handed to {@link #addInput}.
 */
public class OrderingWindowAccumulator
        implements WindowAccumulator
{
    private static final int EXPECTED_POSITIONS = 10_000;

    private final PagesIndex.Factory pagesIndexFactory;
    // pristine snapshot of the delegate, used to start over when rows are added after an output
    private final WindowAccumulator initialDelegate;
    private final List<Type> argumentTypes;
    private final List<Integer> argumentChannels;
    private final List<Integer> sortChannels;
    private final List<SortOrder> sortOrders;
    private final PagesIndex pagesIndex;

    private WindowAccumulator delegate;
    private PageBuilder pageBuilder;
    // true when the delegate already holds every row of pagesIndex (fed in sorted order)
    private boolean pagesIndexSorted;

    /**
     * @param pagesIndexFactory factory for the index that buffers frame rows
     * @param delegate accumulator in its initial (empty) state; a copy of it is kept for later resets
     * @param argumentTypes types of all channels visible through the remapped window index
     * @param argumentChannels channels of the remapped window index; only the count is used here
     * @param sortOrders ordering of the trailing sort-key channels
     */
    public OrderingWindowAccumulator(
            PagesIndex.Factory pagesIndexFactory,
            WindowAccumulator delegate,
            List<Type> argumentTypes,
            List<Integer> argumentChannels,
            List<SortOrder> sortOrders)
    {
        this(
                requireNonNull(pagesIndexFactory, "pagesIndexFactory is null"),
                requireNonNull(delegate, "delegate is null"),
                delegate.copy(),
                requireNonNull(argumentTypes, "argumentTypes is null"),
                argumentChannels,
                sortOrders,
                pagesIndexFactory.newPagesIndex(argumentTypes, EXPECTED_POSITIONS),
                false);
    }

    private OrderingWindowAccumulator(
            PagesIndex.Factory pagesIndexFactory,
            WindowAccumulator delegate,
            WindowAccumulator initialDelegate,
            List<Type> argumentTypes,
            List<Integer> argumentChannels,
            List<SortOrder> sortOrders,
            PagesIndex pagesIndex,
            boolean pagesIndexSorted)
    {
        this.pagesIndexFactory = requireNonNull(pagesIndexFactory, "pagesIndexFactory is null");
        this.delegate = requireNonNull(delegate, "delegate is null");
        this.initialDelegate = requireNonNull(initialDelegate, "initialDelegate is null");
        this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null"));
        this.argumentChannels = ImmutableList.copyOf(requireNonNull(argumentChannels, "argumentChannels is null"));
        this.sortOrders = ImmutableList.copyOf(requireNonNull(sortOrders, "sortOrders is null"));
        this.pagesIndex = requireNonNull(pagesIndex, "pagesIndex is null");
        this.pagesIndexSorted = pagesIndexSorted;
        // sort keys occupy the last sortOrders.size() channels
        this.sortChannels = IntStream.range(this.argumentTypes.size() - this.sortOrders.size(), this.argumentChannels.size())
                .boxed()
                .toList();
        resetPageBuilder();
    }

    private void resetPageBuilder()
    {
        pageBuilder = new PageBuilder(argumentTypes);
    }

    /**
     * Moves any rows still sitting in the page builder into the pages index.
     */
    private void flushPageBuilder()
    {
        if (!pageBuilder.isEmpty()) {
            pagesIndex.addPage(pageBuilder.build());
            resetPageBuilder();
        }
    }

    @Override
    public long getEstimatedSize()
    {
        return delegate.getEstimatedSize() + initialDelegate.getEstimatedSize() + pagesIndex.getEstimatedSize().toBytes() + pageBuilder.getRetainedSizeInBytes();
    }

    /**
     * Returns an independent accumulator holding the same buffered rows and the same delegate state.
     */
    @Override
    public WindowAccumulator copy()
    {
        // rows still buffered in the page builder must be part of the snapshot
        flushPageBuilder();

        PagesIndex pagesIndexCopy = pagesIndexFactory.newPagesIndex(argumentTypes, pagesIndex.getPositionCount());
        pagesIndex.getPages().forEachRemaining(pagesIndexCopy::addPage);

        // Pass the pristine initial delegate separately (the current delegate may already hold rows) and
        // carry the sorted flag over so the copy does not replay rows into an already-populated delegate.
        return new OrderingWindowAccumulator(
                pagesIndexFactory,
                delegate.copy(),
                initialDelegate.copy(),
                argumentTypes,
                argumentChannels,
                sortOrders,
                pagesIndexCopy,
                pagesIndexSorted);
    }

    /**
     * Buffers rows {@code startPosition..endPosition} (both inclusive) of {@code index}.
     */
    @Override
    public void addInput(WindowIndex index, int startPosition, int endPosition)
    {
        if (pagesIndexSorted) {
            pagesIndexSorted = false;
            // operate on delegate as of start
            // nicer would be to add reset() method to WindowAccumulator but it requires reset method in each AccumulatorState class
            delegate = initialDelegate.copy();
        }
        // index is remapped so just go from 0 to argumentChannels.size()
        for (int position = startPosition; position <= endPosition; position++) {
            if (pageBuilder.isFull()) {
                pagesIndex.addPage(pageBuilder.build());
                resetPageBuilder();
            }
            for (int channel = 0; channel < argumentChannels.size(); channel++) {
                ValueBlock value = index.getSingleValueBlock(channel, position).getSingleValueBlock(0);
                pageBuilder.getBlockBuilder(channel).append(value, 0);
            }
            pageBuilder.declarePosition();
        }
    }

    /**
     * Sorts the buffered rows (if not done yet), feeds them to the delegate and writes the delegate's
     * result. A result is always written, also when no rows were buffered, in which case the delegate
     * emits whatever it produces for empty input.
     */
    @Override
    public void output(BlockBuilder blockBuilder)
    {
        if (!pagesIndexSorted) {
            flushPageBuilder();
            int positionCount = pagesIndex.getPositionCount();
            if (positionCount > 0) {
                pagesIndex.sort(sortChannels, sortOrders);
                WindowIndex sortedWindowIndex = new PagesWindowIndex(pagesIndex, 0, positionCount);
                delegate.addInput(sortedWindowIndex, 0, positionCount - 1);
                pagesIndexSorted = true;
            }
        }
        checkState(pageBuilder.isEmpty(), "pageBuilder is not empty");

        delegate.output(blockBuilder);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import static java.lang.Math.min;
import static java.util.Objects.requireNonNull;

class AggregateWindowFunction
public class AggregateWindowFunction
implements WindowFunction
{
private final Supplier<WindowAccumulator> accumulatorFactory;
Expand Down Expand Up @@ -63,7 +63,6 @@ else if ((frameStart == currentStart) && (frameEnd >= currentEnd)) {
else {
buildNewFrame(frameStart, frameEnd);
}

accumulator.output(output);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,11 @@ public static <T extends Expression> List<T> extractExpressions(

private static boolean isAggregation(FunctionCall functionCall, Session session, FunctionResolver functionResolver, AccessControl accessControl)
{
// return ((functionResolver.isAggregationFunction(session, functionCall.getName(), accessControl) || functionCall.getFilter().isPresent())
// && functionCall.getWindow().isEmpty())
// || functionCall.getOrderBy().isPresent();
return ((functionResolver.isAggregationFunction(session, functionCall.getName(), accessControl) || functionCall.getFilter().isPresent())
&& functionCall.getWindow().isEmpty())
|| functionCall.getOrderBy().isPresent();
&& functionCall.getWindow().isEmpty());
}

private static boolean isWindow(FunctionCall functionCall, Session session, FunctionResolver functionResolver, AccessControl accessControl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4334,9 +4334,9 @@ private List<FunctionCall> analyzeWindowFunctions(QuerySpecification node, List<
throw semanticException(NOT_SUPPORTED, node, "FILTER is not yet supported for window functions");
}

if (windowFunction.getOrderBy().isPresent()) {
throw semanticException(NOT_SUPPORTED, windowFunction, "Window function with ORDER BY is not supported");
}
// if (windowFunction.getOrderBy().isPresent()) {
// throw semanticException(NOT_SUPPORTED, windowFunction, "Window function with ORDER BY is not supported");
// }

List<Expression> nestedWindowExpressions = extractWindowExpressions(windowFunction.getArguments());
if (!nestedWindowExpressions.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.DistinctAccumulatorFactory;
import io.trino.operator.aggregation.OrderedAccumulatorFactory;
import io.trino.operator.aggregation.OrderingWindowAccumulator;
import io.trino.operator.aggregation.partial.PartialAggregationController;
import io.trino.operator.exchange.LocalExchange;
import io.trino.operator.exchange.LocalExchangeSinkOperator.LocalExchangeSinkOperatorFactory;
Expand Down Expand Up @@ -134,6 +135,7 @@
import io.trino.operator.project.PageProcessor;
import io.trino.operator.project.PageProjection;
import io.trino.operator.unnest.UnnestOperator;
import io.trino.operator.window.AggregateWindowFunction;
import io.trino.operator.window.AggregationWindowFunctionSupplier;
import io.trino.operator.window.FrameInfo;
import io.trino.operator.window.PartitionerSupplier;
Expand Down Expand Up @@ -171,6 +173,7 @@
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionKind;
import io.trino.spi.function.WindowFunction;
import io.trino.spi.function.WindowFunctionSupplier;
import io.trino.spi.function.table.TableFunctionProcessorProvider;
import io.trino.spi.predicate.Domain;
Expand Down Expand Up @@ -271,6 +274,7 @@
import io.trino.sql.relational.SqlToRowExpressionTranslator;
import io.trino.type.BlockTypeOperators;
import io.trino.type.FunctionType;
import org.jetbrains.annotations.NotNull;
import org.objectweb.asm.MethodTooLargeException;

import java.util.AbstractMap.SimpleEntry;
Expand Down Expand Up @@ -1114,7 +1118,6 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext
Optional<Integer> sortKeyChannelForEndComparison = Optional.empty();
Optional<Integer> sortKeyChannel = Optional.empty();
Optional<FrameInfo.Ordering> ordering = Optional.empty();

Frame frame = entry.getValue().getFrame();
if (frame.getStartValue().isPresent()) {
frameStartChannel = Optional.of(source.getLayout().get(frame.getStartValue().get()));
Expand Down Expand Up @@ -1145,15 +1148,15 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext

WindowNode.Function function = entry.getValue();
ResolvedFunction resolvedFunction = function.getResolvedFunction();
ImmutableList.Builder<Integer> arguments = ImmutableList.builder();
ArrayList<Integer> argumentChannels = new ArrayList<>();
for (Expression argument : function.getArguments()) {
if (!(argument instanceof Lambda)) {
Symbol argumentSymbol = Symbol.from(argument);
arguments.add(source.getLayout().get(argumentSymbol));
argumentChannels.add(source.getLayout().get(argumentSymbol));
}
}
Symbol symbol = entry.getKey();
WindowFunctionSupplier windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction);

Type type = resolvedFunction.signature().getReturnType();

List<Lambda> lambdas = function.getArguments().stream()
Expand All @@ -1165,8 +1168,32 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext
.map(FunctionType.class::cast)
.collect(toImmutableList());

WindowFunctionSupplier windowFunctionSupplier;
if (function.getOrderingScheme().isPresent()) {
OrderingScheme orderingScheme = function.getOrderingScheme().orElseThrow();
List<Symbol> sortKeys = orderingScheme.orderBy();
List<SortOrder> sortOrders = sortKeys.stream()
.map(orderingScheme::ordering)
.collect(toImmutableList());

List<Type> argumentTypes = argumentChannels.stream()
.map(channel -> source.getTypes().get(channel))
.collect(toImmutableList());

windowFunctionSupplier = getOrderedWindowFunctionImplementation(
resolvedFunction,
argumentTypes,
argumentChannels,
sortOrders);
}
else {
windowFunctionSupplier = getWindowFunctionImplementation(resolvedFunction);
}

List<Supplier<Object>> lambdaProviders = makeLambdaProviders(lambdas, windowFunctionSupplier.getLambdaInterfaces(), functionTypes);
windowFunctionsBuilder.add(window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), lambdaProviders, arguments.build()));
WindowFunctionDefinition windowFunction = window(windowFunctionSupplier, type, frameInfo, function.isIgnoreNulls(), lambdaProviders, ImmutableList.copyOf(argumentChannels));

windowFunctionsBuilder.add(windowFunction);
windowFunctionOutputSymbolsBuilder.add(symbol);
}

Expand Down Expand Up @@ -1210,17 +1237,47 @@ public PhysicalOperation visitWindow(WindowNode node, LocalExecutionPlanContext
private WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
{
if (resolvedFunction.functionKind() == FunctionKind.AGGREGATE) {
return uncheckedCacheGet(aggregationWindowFunctionSupplierCache, new FunctionKey(resolvedFunction.functionId(), resolvedFunction.signature()), () -> {
AggregationImplementation aggregationImplementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction);
return new AggregationWindowFunctionSupplier(
resolvedFunction.signature(),
aggregationImplementation,
resolvedFunction.functionNullability());
});
return getAggregationWindowFunctionSupplier(resolvedFunction);
}
return plannerContext.getFunctionManager().getWindowFunctionSupplier(resolvedFunction);
}

/**
 * Looks up the {@link AggregationWindowFunctionSupplier} for an aggregate function through
 * {@code aggregationWindowFunctionSupplierCache}, keyed by the function id and signature.
 * The loader passed to the cache lookup builds the supplier from the function's aggregation implementation.
 *
 * @param resolvedFunction function whose kind must be {@link FunctionKind#AGGREGATE}
 * @throws IllegalArgumentException if the function is not an aggregate
 */
private @NotNull AggregationWindowFunctionSupplier getAggregationWindowFunctionSupplier(ResolvedFunction resolvedFunction)
{
    checkArgument(resolvedFunction.functionKind() == FunctionKind.AGGREGATE);

    FunctionKey cacheKey = new FunctionKey(resolvedFunction.functionId(), resolvedFunction.signature());
    return uncheckedCacheGet(
            aggregationWindowFunctionSupplierCache,
            cacheKey,
            () -> {
                AggregationImplementation implementation = plannerContext.getFunctionManager().getAggregationImplementation(resolvedFunction);
                return new AggregationWindowFunctionSupplier(resolvedFunction.signature(), implementation, resolvedFunction.functionNullability());
            });
}

/**
 * Creates a {@link WindowFunctionSupplier} for an aggregate that carries its own ORDER BY.
 * Every window function it creates is an {@link AggregateWindowFunction} whose accumulators are
 * {@link OrderingWindowAccumulator}s wrapping the regular aggregation window accumulator.
 *
 * @param resolvedFunction the aggregate function (must be of kind AGGREGATE, checked by the supplier lookup)
 * @param argumentTypes types passed through to the ordering accumulator
 * @param argumentChannels channels passed through to the ordering accumulator
 * @param sortOrders sort orders passed through to the ordering accumulator
 */
private WindowFunctionSupplier getOrderedWindowFunctionImplementation(
        ResolvedFunction resolvedFunction,
        List<Type> argumentTypes,
        List<Integer> argumentChannels,
        List<SortOrder> sortOrders)
{
    AggregationWindowFunctionSupplier aggregationSupplier = getAggregationWindowFunctionSupplier(resolvedFunction);

    return new WindowFunctionSupplier()
    {
        @Override
        public WindowFunction createWindowFunction(boolean ignoreNulls, List<Supplier<Object>> lambdaProviders)
        {
            // NOTE(review): this flag reflects the underlying aggregation, while the accumulator actually
            // used is OrderingWindowAccumulator — confirm that removal of input is handled correctly for it.
            boolean hasRemoveInput = plannerContext.getFunctionManager()
                    .getAggregationImplementation(resolvedFunction)
                    .getWindowAccumulator()
                    .isPresent();
            return new AggregateWindowFunction(
                    () -> new OrderingWindowAccumulator(
                            pagesIndexFactory,
                            aggregationSupplier.createWindowAccumulator(lambdaProviders),
                            argumentTypes,
                            argumentChannels,
                            sortOrders),
                    hasRemoveInput);
        }

        @Override
        public List<Class<?>> getLambdaInterfaces()
        {
            // lambda arguments are those of the wrapped aggregation
            return aggregationSupplier.getLambdaInterfaces();
        }
    };
}

@Override
public PhysicalOperation visitPatternRecognition(PatternRecognitionNode node, LocalExecutionPlanContext context)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ public RelationPlan planExpand(Query query)
NodeAndMappings checkConvergenceStep = copy(recursionStep, mappings);
Symbol countSymbol = symbolAllocator.newSymbol("count", BIGINT);
ResolvedFunction function = plannerContext.getMetadata().resolveBuiltinFunction("count", ImmutableList.of());
WindowNode.Function countFunction = new WindowNode.Function(function, ImmutableList.of(), DEFAULT_FRAME, false);
WindowNode.Function countFunction = new WindowNode.Function(function, ImmutableList.of(), Optional.empty(), DEFAULT_FRAME, false);

WindowNode windowNode = new WindowNode(
idAllocator.getNextId(),
Expand Down Expand Up @@ -1457,6 +1457,9 @@ private PlanBuilder planWindowFunctions(Node node, PlanBuilder subPlan, List<io.
inputsBuilder.addAll(windowFunction.getArguments().stream()
.filter(argument -> !(argument instanceof LambdaExpression)) // lambda expression is generated at execution time
.collect(Collectors.toList()));
inputsBuilder.addAll(getSortItemsFromOrderBy(windowFunction.getOrderBy()).stream()
.map(SortItem::getSortKey)
.iterator());
}

List<io.trino.sql.tree.Expression> inputs = inputsBuilder.build();
Expand Down Expand Up @@ -1778,6 +1781,7 @@ private PlanBuilder planWindow(
return coercions.get(argument).toSymbolReference();
})
.collect(toImmutableList()),
windowFunction.getOrderBy().map(orderBy -> translateOrderingScheme(orderBy.getSortItems(), coercions::get)),
frame,
nullTreatment == NullTreatment.IGNORE);

Expand Down Expand Up @@ -1859,6 +1863,7 @@ private PlanBuilder planPatternRecognition(
return coercions.get(argument).toSymbolReference();
})
.collect(toImmutableList()),
Optional.empty(),
baseFrame,
nullTreatment == NullTreatment.IGNORE);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ public RewriteResult visitTopN(TopNNode node, Void context)
WindowNode.Function rowNumberFunction = new WindowNode.Function(
metadata.resolveBuiltinFunction("row_number", ImmutableList.of()),
ImmutableList.of(),
Optional.empty(),
DEFAULT_FRAME,
false);
WindowNode windowNode = new WindowNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public static PlanNode rewriteLimitWithTiesWithPartitioning(LimitNode limitNode,
WindowNode.Function rankFunction = new WindowNode.Function(
metadata.resolveBuiltinFunction("rank", ImmutableList.of()),
ImmutableList.of(),
Optional.empty(),
DEFAULT_FRAME,
false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ private static NodeWithSymbols planWindowFunctionsForSource(
source,
specification,
ImmutableMap.of(
rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), FULL_FRAME, false),
partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), FULL_FRAME, false)),
rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), Optional.empty(), FULL_FRAME, false),
partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), Optional.empty(), FULL_FRAME, false)),
Optional.empty(),
ImmutableSet.of(),
0);
Expand Down
Loading

0 comments on commit e43e7f9

Please sign in to comment.