Skip to content

Commit

Permalink
Prune ORDER BY in window aggregation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
losipiuk committed Oct 26, 2024
1 parent 0018b2e commit 74340e7
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
import io.trino.sql.planner.iterative.rule.PruneMergeSourceColumns;
import io.trino.sql.planner.iterative.rule.PruneOffsetColumns;
import io.trino.sql.planner.iterative.rule.PruneOrderByInAggregation;
import io.trino.sql.planner.iterative.rule.PruneOrderByInWindowAggregation;
import io.trino.sql.planner.iterative.rule.PruneOutputSourceColumns;
import io.trino.sql.planner.iterative.rule.PrunePattenRecognitionColumns;
import io.trino.sql.planner.iterative.rule.PrunePatternRecognitionSourceColumns;
Expand Down Expand Up @@ -460,6 +461,7 @@ public PlanOptimizers(
new MergeLimitWithDistinct(),
new PruneCountAggregationOverScalar(metadata),
new PruneOrderByInAggregation(metadata),
new PruneOrderByInWindowAggregation(metadata),
new RewriteSpatialPartitioningAggregation(plannerContext),
new SimplifyCountOverConstant(plannerContext),
new PreAggregateCaseAggregations(plannerContext),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.Maps;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.function.FunctionKind;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.WindowNode;

import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;

import static io.trino.sql.planner.plan.Patterns.window;
import static java.util.Objects.requireNonNull;

public class PruneOrderByInWindowAggregation
implements Rule<WindowNode>
{
private static final Pattern<WindowNode> PATTERN = window();
private final Metadata metadata;

public PruneOrderByInWindowAggregation(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
}

@Override
public Pattern<WindowNode> getPattern()
{
return PATTERN;
}

@Override
public Result apply(WindowNode node, Captures captures, Context context)
{
Predicate<WindowNode.Function> isOrderInsensitive = windowFunction -> windowFunction.getResolvedFunction().functionKind() == FunctionKind.AGGREGATE &&
!metadata.getAggregationFunctionMetadata(context.getSession(), windowFunction.getResolvedFunction()).isOrderSensitive();

if (node.getWindowFunctions().values().stream().noneMatch(isOrderInsensitive)) {
return Result.empty();
}

Map<Symbol, WindowNode.Function> prunedFunctions = Maps.transformValues(node.getWindowFunctions(), windowFunction -> {
if (isOrderInsensitive.test(windowFunction)) {
return new WindowNode.Function(
windowFunction.getResolvedFunction(),
windowFunction.getArguments(),
Optional.empty(), // prune
windowFunction.getFrame(),
windowFunction.isIgnoreNulls());
}
return windowFunction;
});

return Result.ofPlanNode(new WindowNode(
node.getId(),
node.getSource(),
node.getSpecification(),
prunedFunctions,
node.getHashSymbol(),
node.getPrePartitionedInputs(),
node.getPreSortedOrderPrefix()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,12 @@ public static ExpectedValueProvider<AggregationFunction> aggregationFunction(Str

public static ExpectedValueProvider<WindowFunction> windowFunction(String name, List<String> args, WindowNode.Frame frame)
{
return new WindowFunctionProvider(name, frame, toSymbolAliases(args));
return windowFunction(name, args, frame, ImmutableList.of());
}

public static ExpectedValueProvider<WindowFunction> windowFunction(String name, List<String> args, WindowNode.Frame frame, List<PlanMatchPattern.Ordering> orderBy)
{
return new WindowFunctionProvider(name, frame, toSymbolAliases(args), orderBy);
}

public static List<Expression> toSymbolReferences(List<PlanTestSymbol> aliases, SymbolAliases symbolAliases)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,25 @@

import com.google.common.collect.ImmutableList;
import io.trino.sql.ir.Expression;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.plan.WindowNode;

import java.util.List;
import java.util.Optional;

import static java.util.Objects.requireNonNull;

public record WindowFunction(
String name,
WindowNode.Frame frame,
List<Expression> arguments)
List<Expression> arguments,
Optional<OrderingScheme> orderBy)
{
public WindowFunction(String name, WindowNode.Frame frame, List<Expression> arguments)
public WindowFunction(String name, WindowNode.Frame frame, List<Expression> arguments, Optional<OrderingScheme> orderBy)
{
this.name = requireNonNull(name, "name is null");
this.frame = requireNonNull(frame, "frame is null");
this.arguments = ImmutableList.copyOf(arguments);
this.orderBy = requireNonNull(orderBy, "orderBy is null");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.trino.sql.planner.plan.WindowNode.Function;

import java.util.Map;
import java.util.Objects;
import java.util.Optional;

import static com.google.common.base.MoreObjects.toStringHelper;
Expand Down Expand Up @@ -70,6 +71,7 @@ private boolean windowFunctionMatches(Function windowFunction, WindowFunction ex
{
return expectedCall.name().equals(windowFunction.getResolvedFunction().signature().getName().getFunctionName()) &&
WindowFrameMatcher.matches(expectedCall.frame(), windowFunction.getFrame(), aliases) &&
Objects.equals(expectedCall.orderBy(), windowFunction.getOrderingScheme()) &&
expectedCall.arguments().equals(windowFunction.getArguments());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,18 @@
package io.trino.sql.planner.assertions;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.WindowNode;

import java.util.List;
import java.util.Optional;

import static io.trino.sql.planner.assertions.PlanMatchPattern.toSymbolReferences;
import static io.trino.type.UnknownType.UNKNOWN;
import static java.util.Objects.requireNonNull;

final class WindowFunctionProvider
Expand All @@ -27,26 +34,42 @@ final class WindowFunctionProvider
private final String name;
private final WindowNode.Frame frame;
private final List<PlanTestSymbol> args;
private final List<PlanMatchPattern.Ordering> orderBy;

public WindowFunctionProvider(String name, WindowNode.Frame frame, List<PlanTestSymbol> args)
public WindowFunctionProvider(String name, WindowNode.Frame frame, List<PlanTestSymbol> args, List<PlanMatchPattern.Ordering> orderBy)
{
this.name = requireNonNull(name, "name is null");
this.frame = requireNonNull(frame, "frame is null");
this.args = requireNonNull(args, "args is null");
this.orderBy = ImmutableList.copyOf(orderBy);
}

@Override
public String toString()
{
return "%s(%s) %s".formatted(
return "%s(%s%s) %s".formatted(
name,
Joiner.on(", ").join(args),
orderBy.isEmpty() ? "" : " ORDER BY " + Joiner.on(", ").join(orderBy),
frame);
}

@Override
public WindowFunction getExpectedValue(SymbolAliases aliases)
{
return new WindowFunction(name, frame, toSymbolReferences(args, aliases));
Optional<OrderingScheme> orderByClause = Optional.empty();
if (!orderBy.isEmpty()) {
ImmutableList.Builder<Symbol> fields = ImmutableList.builder();
ImmutableMap.Builder<Symbol, SortOrder> orders = ImmutableMap.builder();

for (PlanMatchPattern.Ordering ordering : this.orderBy) {
Symbol symbol = new Symbol(UNKNOWN, aliases.get(ordering.getField()).name());
fields.add(symbol);
orders.put(symbol, ordering.getSortOrder());
}
orderByClause = Optional.of(new OrderingScheme(fields.build(), orders.buildOrThrow()));
}

return new WindowFunction(name, frame, toSymbolReferences(args, aliases), orderByClause);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.connector.SortOrder;
import io.trino.sql.ir.Reference;
import io.trino.sql.planner.OrderingScheme;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.DataOrganizationSpecification;
import io.trino.sql.planner.plan.WindowNode;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Optional;

import static io.trino.metadata.MetadataManager.createTestMetadataManager;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.planner.assertions.PlanMatchPattern.sort;
import static io.trino.sql.planner.assertions.PlanMatchPattern.specification;
import static io.trino.sql.planner.assertions.PlanMatchPattern.values;
import static io.trino.sql.planner.assertions.PlanMatchPattern.window;
import static io.trino.sql.planner.assertions.PlanMatchPattern.windowFunction;
import static io.trino.sql.planner.plan.WindowNode.Frame.DEFAULT_FRAME;
import static io.trino.sql.tree.SortItem.NullOrdering.LAST;
import static io.trino.sql.tree.SortItem.Ordering.ASCENDING;
import static io.trino.type.UnknownType.UNKNOWN;

public class TestPruneOrderByInWindowAggregation
extends BaseRuleTest
{
private static final Metadata METADATA = createTestMetadataManager();

@Test
public void testBasics()
{
tester().assertThat(new PruneOrderByInWindowAggregation(METADATA))
.on(this::buildWindowNode)
.matches(
window(
windowMatcherBuilder -> windowMatcherBuilder
.specification(specification(
ImmutableList.of("key"),
ImmutableList.of(),
ImmutableMap.of()))
.addFunction(
"avg",
windowFunction("avg", ImmutableList.of("input"), DEFAULT_FRAME))
.addFunction(
"array_agg",
windowFunction("array_agg", ImmutableList.of("input"), DEFAULT_FRAME, ImmutableList.of(sort("input", ASCENDING, LAST)))),
values("input", "key", "keyHash", "mask")));
}

private WindowNode buildWindowNode(PlanBuilder planBuilder)
{
Symbol avg = planBuilder.symbol("avg");
Symbol arrayAgg = planBuilder.symbol("araray_agg");
Symbol input = planBuilder.symbol("input");
Symbol key = planBuilder.symbol("key");
Symbol keyHash = planBuilder.symbol("keyHash");
Symbol mask = planBuilder.symbol("mask");
List<Symbol> sourceSymbols = ImmutableList.of(input, key, keyHash, mask);

ResolvedFunction avgFunction = createTestMetadataManager().resolveBuiltinFunction("avg", fromTypes(BIGINT));
ResolvedFunction arrayAggFunction = createTestMetadataManager().resolveBuiltinFunction("array_agg", fromTypes(BIGINT));

return planBuilder.window(
new DataOrganizationSpecification(ImmutableList.of(planBuilder.symbol("key", BIGINT)), Optional.empty()),
ImmutableMap.of(
avg, new WindowNode.Function(avgFunction,
ImmutableList.of(new Reference(BIGINT, "input")),
Optional.of(new OrderingScheme(
ImmutableList.of(new Symbol(UNKNOWN, "input")),
ImmutableMap.of(new Symbol(UNKNOWN, "input"), SortOrder.ASC_NULLS_LAST))),
DEFAULT_FRAME,
false),
arrayAgg, new WindowNode.Function(arrayAggFunction,
ImmutableList.of(new Reference(BIGINT, "input")),
Optional.of(new OrderingScheme(
ImmutableList.of(new Symbol(UNKNOWN, "input")),
ImmutableMap.of(new Symbol(UNKNOWN, "input"), SortOrder.ASC_NULLS_LAST))),
DEFAULT_FRAME,
false)),
planBuilder.values(sourceSymbols, ImmutableList.of()));
}
}

0 comments on commit 74340e7

Please sign in to comment.