From d123b863606aa753441e7684c90972f546c84a5c Mon Sep 17 00:00:00 2001 From: Seaven Date: Thu, 24 Oct 2024 17:13:20 +0800 Subject: [PATCH] [BugFix] Fix eliminate join lose predicate (#52273) Signed-off-by: Seaven (cherry picked from commit aeafef0195eff9d44a0e2d945c4795d52ff262a8) --- .../EliminateJoinWithConstantRule.java | 33 +++++++++++-------- .../java/com/starrocks/sql/plan/JoinTest.java | 12 +++++++ 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateJoinWithConstantRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateJoinWithConstantRule.java index 4d8b951c34b0f..3d8c36c81abca 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateJoinWithConstantRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/EliminateJoinWithConstantRule.java @@ -59,6 +59,10 @@ private EliminateJoinWithConstantRule(int index) { @Override public boolean check(OptExpression input, OptimizerContext context) { if (OperatorType.LOGICAL_PROJECT.equals(input.inputAt(constantIndex).getOp().getOpType())) { + LogicalProjectOperator project = input.inputAt(constantIndex).getOp().cast(); + if (!project.getColumnRefMap().values().stream().allMatch(ScalarOperator::isConstant)) { + return false; + } OptExpression optExpression = input.inputAt(constantIndex); OptExpression valuesOpt = optExpression.inputAt(0); return checkValuesOptExpression(valuesOpt); @@ -106,35 +110,38 @@ public List onMatch(OptExpression joinOpt, OptExpression otherOpt, OptExpression constantOpt, OptimizerContext context) { - Map outputs = Maps.newHashMap(); + LogicalJoinOperator joinOperator = (LogicalJoinOperator) joinOpt.getOp(); - JoinOperator joinType = joinOperator.getJoinType(); + LogicalProjectOperator projectOperator = (LogicalProjectOperator) constantOpt.getOp(); + ScalarOperator condition = joinOperator.getOnPredicate(); - ScalarOperator predicate = otherOpt.getOp().getPredicate(); + ScalarOperator predicate = joinOperator.getPredicate(); + // rewrite join's on-predicate with constant column values - LogicalProjectOperator projectOperator = (LogicalProjectOperator) constantOpt.getOp(); ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(projectOperator.getColumnRefMap()); ScalarOperator rewrittenCondition = rewriter.rewrite(condition); + ScalarOperator rewrittenPredicate = rewriter.rewrite(predicate); + // output join and constant opt's output columns + Map outputs = Maps.newHashMap(); joinOpt.getOutputColumns().getStream().map(context.getColumnRefFactory()::getColumnRef) .forEach(ref -> outputs.put(ref, rewriter.rewrite(ref))); if (joinOperator.getJoinType().isOuterJoin()) { // transform join's on-predicate with case-when operator - constantOpt.getRowOutputInfo().getColumnRefMap().entrySet().stream() - .forEach(entry -> { - ScalarOperator transformed = transformOuterJoinOnPredicate( - joinOperator, entry.getValue(), rewrittenCondition); - outputs.put(entry.getKey(), transformed); - }); + constantOpt.getRowOutputInfo().getColumnRefMap().forEach((key, value) -> { + ScalarOperator t = transformOuterJoinOnPredicate(joinOperator, value, rewrittenCondition); + outputs.put(key, t); + }); } else { - predicate = Utils.compoundAnd(predicate, rewrittenCondition); + rewrittenPredicate = Utils.compoundAnd(rewrittenPredicate, rewrittenCondition); } + LogicalProjectOperator project = new LogicalProjectOperator(outputs); OptExpression result = OptExpression.create(project, otherOpt); // save predicate - if (predicate != null) { - result = OptExpression.create(new LogicalFilterOperator(predicate), result); + if (rewrittenPredicate != null) { + result = OptExpression.create(new LogicalFilterOperator(rewrittenPredicate), result); } return Lists.newArrayList(result); } diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java index 96515cfff4ee8..5133b11fce2f8 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/JoinTest.java @@ -3290,4 +3290,16 @@ public void testOuterJoinOnConstValue() throws Exception { " | join op: RIGHT OUTER JOIN\n" + " | colocate: false, reason:"); } + + @Test + public void testJoinOnConstValue() throws Exception { + String query = "select coalesce(b.v1, a.v1) as v1, a.v2 \n" + + "from t0 a left join (select 'cccc' as v1, 'dddd' as v2) b on a.v1 = b.v1 \n" + + "where coalesce(b.v1, a.v1) = '1';"; + String plan = getFragmentPlan(query); + assertContainsIgnoreColRefs(plan, " 0:OlapScanNode\n" + + " TABLE: t0\n" + + " PREAGGREGATION: ON\n" + + " PREDICATES: coalesce('cccc', CAST(1: v1 AS VARCHAR)) = '1'"); + } }