diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java index 9544d31855ab2e..85888e21be64a7 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/AggregationFunctionProvider.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.spi.connector.SortOrder; import io.trino.sql.ir.Expression; +import io.trino.sql.ir.Reference; import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; @@ -25,7 +26,6 @@ import java.util.Optional; import static io.trino.sql.planner.assertions.PlanMatchPattern.toSymbolReferences; -import static io.trino.type.UnknownType.UNKNOWN; import static java.lang.String.format; import static java.util.Objects.requireNonNull; @@ -69,7 +69,8 @@ public AggregationFunction getExpectedValue(SymbolAliases aliases) ImmutableMap.Builder orders = ImmutableMap.builder(); for (PlanMatchPattern.Ordering ordering : this.orderBy) { - Symbol symbol = new Symbol(UNKNOWN, aliases.get(ordering.getField()).name()); + Reference reference = aliases.get(ordering.getField()); + Symbol symbol = new Symbol(reference.type(), reference.name()); fields.add(symbol); orders.put(symbol, ordering.getSortOrder()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java index 0a521d8b7cb8d1..40fe48829dbcd0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPruneOrderByInAggregation.java @@ -38,7 +38,6 @@ import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; import static io.trino.sql.tree.SortItem.NullOrdering.LAST; import static io.trino.sql.tree.SortItem.Ordering.ASCENDING; -import static io.trino.type.UnknownType.UNKNOWN; public class TestPruneOrderByInAggregation extends BaseRuleTest @@ -81,16 +80,16 @@ private AggregationNode buildAggregation(PlanBuilder planBuilder) "avg", ImmutableList.of(new Reference(BIGINT, "input")), new OrderingScheme( - ImmutableList.of(new Symbol(UNKNOWN, "input")), - ImmutableMap.of(new Symbol(UNKNOWN, "input"), SortOrder.ASC_NULLS_LAST))), + ImmutableList.of(new Symbol(BIGINT, "input")), + ImmutableMap.of(new Symbol(BIGINT, "input"), SortOrder.ASC_NULLS_LAST))), ImmutableList.of(BIGINT), mask) .addAggregation(arrayAgg, PlanBuilder.aggregation( "array_agg", ImmutableList.of(new Reference(BIGINT, "input")), new OrderingScheme( - ImmutableList.of(new Symbol(UNKNOWN, "input")), - ImmutableMap.of(new Symbol(UNKNOWN, "input"), SortOrder.ASC_NULLS_LAST))), + ImmutableList.of(new Symbol(BIGINT, "input")), + ImmutableMap.of(new Symbol(BIGINT, "input"), SortOrder.ASC_NULLS_LAST))), ImmutableList.of(BIGINT), mask) .hashSymbol(keyHash)