Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test: Add varchar/char cast pushdown support in Redshift #23936

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ jobs:
strategy:
fail-fast: false
matrix: ${{ fromJson(needs.build-test-matrix.outputs.matrix) }}
timeout-minutes: 60
timeout-minutes: 180
steps:
- uses: actions/checkout@v4
with:
Expand Down
1 change: 1 addition & 0 deletions plugin/trino-redshift/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@
<!-- JDBC operations performed on the ephemeral AWS Redshift cluster. -->
<include>**/TestRedshiftCastPushdown.java</include>
<include>**/TestRedshiftConnectorSmokeTest.java</include>
<include>**/TestRedshiftConnectorTest.java</include>
</includes>
</configuration>
</plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@
import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction;
import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling;
import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR;
import static io.trino.plugin.redshift.RedshiftErrorCode.REDSHIFT_INVALID_TYPE;
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
Expand Down Expand Up @@ -666,10 +665,9 @@ public Optional<ColumnMapping> toColumnMapping(ConnectorSession session, Connect
RedshiftClient::writeChar));

case Types.VARCHAR: {
if (type.columnSize().isEmpty()) {
throw new TrinoException(REDSHIFT_INVALID_TYPE, "column size not present");
}
int length = type.requiredColumnSize();
// A Redshift column exposes its precision, with a maximum of `REDSHIFT_MAX_VARCHAR`.
// Defaulting to `VarcharType.MAX_LENGTH` (instead of `REDSHIFT_MAX_VARCHAR`) because synthetic columns created by Trino use unbounded varchar.
int length = type.columnSize().orElse(VarcharType.MAX_LENGTH);
return Optional.of(varcharColumnMapping(
length < VarcharType.MAX_LENGTH
? createVarcharType(length)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@
import io.trino.plugin.jdbc.expression.AbstractRewriteCast;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

import java.util.List;
import java.util.Optional;
import java.util.function.BiFunction;

import static java.sql.Types.BIGINT;
import static java.sql.Types.BIT;
import static java.sql.Types.CHAR;
import static java.sql.Types.INTEGER;
import static java.sql.Types.NUMERIC;
import static java.sql.Types.SMALLINT;
import static java.sql.Types.VARCHAR;

public class RewriteCast
extends AbstractRewriteCast
Expand All @@ -57,6 +61,10 @@ protected Optional<JdbcTypeHandle> toJdbcTypeHandle(JdbcTypeHandle sourceType, T
Optional.of(new JdbcTypeHandle(INTEGER, Optional.of(integerType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
case BigintType bigintType ->
Optional.of(new JdbcTypeHandle(BIGINT, Optional.of(bigintType.getBaseName()), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()));
case VarcharType varcharType ->
Optional.of(new JdbcTypeHandle(VARCHAR, Optional.of(varcharType.getBaseName()), varcharType.getLength(), Optional.empty(), Optional.empty(), Optional.empty()));
case CharType charType ->
Optional.of(new JdbcTypeHandle(CHAR, Optional.of(charType.getBaseName()), Optional.of(charType.getLength()), Optional.empty(), Optional.empty(), Optional.empty()));
default -> Optional.empty();
};
}
Expand All @@ -66,6 +74,10 @@ private boolean pushdownSupported(JdbcTypeHandle sourceType, Type targetType)
return switch (targetType) {
case SmallintType _, IntegerType _, BigintType _ ->
SUPPORTED_SOURCE_TYPE_FOR_INTEGRAL_CAST.contains(sourceType.jdbcType());
// char -> varchar is not supported because Redshift does not pad char values with trailing blanks, whereas Trino does.
case VarcharType _ -> VARCHAR == sourceType.jdbcType();
// varchar -> char is unsupported because varchar supports multi-byte characters, whereas char supports only single-byte characters.
case CharType _ -> CHAR == sourceType.jdbcType();
default -> false;
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.testing.QueryRunner;
import io.trino.testing.sql.SqlExecutor;
import io.trino.testing.sql.TestTable;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -79,16 +80,25 @@ public void setupTable()
.addColumn("c_numeric_10_2", "numeric(10, 2)", asList(1.23, 2.67, null))
.addColumn("c_numeric_19_2", "numeric(19, 2)", asList(1.23, 2.67, null)) // Equal to REDSHIFT_DECIMAL_CUTOFF_PRECISION
.addColumn("c_numeric_30_2", "numeric(30, 2)", asList(1.23, 2.67, null))
.addColumn("c_char", "char", asList("'I'", "'P'", null))
.addColumn("c_character", "character", asList("'I'", "'P'", null))
.addColumn("c_char_10", "char(10)", asList("'India'", "'Poland'", null))
.addColumn("c_char_50", "char(50)", asList("'India'", "'Poland'", null))
.addColumn("c_char_4096", "char(4096)", asList("'India'", "'Poland'", null)) // Equal to REDSHIFT_MAX_CHAR

// the number of Unicode code points in 攻殻機動隊 is 5, and in 😂 is 1.
.addColumn("c_varchar_unicode", "varchar(15)", asList("'攻殻機動隊'", "'😂'", null))
.addColumn("c_nvarchar_unicode", "nvarchar(15)", asList("'攻殻機動隊'", "'😂'", null))

.addColumn("c_nchar", "nchar", asList("'I'", "'P'", null))
.addColumn("c_nchar_10", "nchar(10)", asList("'India'", "'Poland'", null))
.addColumn("c_nchar_50", "nchar(50)", asList("'India'", "'Poland'", null))
.addColumn("c_nchar_4096", "nchar(4096)", asList("'India'", "'Poland'", null)) // Equal to REDSHIFT_MAX_CHAR
.addColumn("c_bpchar", "bpchar", asList("'India'", "'Poland'", null))
.addColumn("c_varchar_10", "varchar(10)", asList("'India'", "'Poland'", null))
.addColumn("c_varchar_50", "varchar(50)", asList("'India'", "'Poland'", null))
.addColumn("c_varchar_65535", "varchar(65535)", asList("'India'", "'Poland'", null)) // Equal to REDSHIFT_MAX_VARCHAR
.addColumn("c_nvarchar", "nvarchar", asList("'India'", "'Poland'", null))
.addColumn("c_nvarchar_10", "nvarchar(10)", asList("'India'", "'Poland'", null))
.addColumn("c_nvarchar_50", "nvarchar(50)", asList("'India'", "'Poland'", null))
.addColumn("c_nvarchar_65535", "nvarchar(65535)", asList("'India'", "'Poland'", null)) // Greater than REDSHIFT_MAX_VARCHAR
Expand Down Expand Up @@ -144,16 +154,25 @@ public void setupTable()
.addColumn("c_numeric_10_2", "numeric(10, 2)", asList(1.23, 22.67, null))
.addColumn("c_numeric_19_2", "numeric(19, 2)", asList(1.23, 22.67, null)) // Equal to REDSHIFT_DECIMAL_CUTOFF_PRECISION
.addColumn("c_numeric_30_2", "numeric(30, 2)", asList(1.23, 22.67, null))
.addColumn("c_char", "char", asList("'I'", "'F'", null))
.addColumn("c_character", "character", asList("'I'", "'F'", null))
.addColumn("c_char_10", "char(10)", asList("'India'", "'France'", null))
.addColumn("c_char_50", "char(50)", asList("'India'", "'France'", null))
.addColumn("c_char_4096", "char(4096)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_CHAR

// the number of Unicode code points in 攻殻機動隊 is 5, and in 😂 is 1.
.addColumn("c_varchar_unicode", "varchar(15)", asList("'攻殻機動隊'", "'😂'", null))
.addColumn("c_nvarchar_unicode", "nvarchar(15)", asList("'攻殻機動隊'", "'😂'", null))

.addColumn("c_nchar", "nchar", asList("'I'", "'F'", null))
.addColumn("c_nchar_10", "nchar(10)", asList("'India'", "'France'", null))
.addColumn("c_nchar_50", "nchar(50)", asList("'India'", "'France'", null))
.addColumn("c_nchar_4096", "nchar(4096)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_CHAR
.addColumn("c_bpchar", "bpchar", asList("'India'", "'France'", null))
.addColumn("c_varchar_10", "varchar(10)", asList("'India'", "'France'", null))
.addColumn("c_varchar_50", "varchar(50)", asList("'India'", "'France'", null))
.addColumn("c_varchar_65535", "varchar(65535)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_VARCHAR
.addColumn("c_nvarchar", "nvarchar", asList("'India'", "'France'", null))
.addColumn("c_nvarchar_10", "nvarchar(10)", asList("'India'", "'France'", null))
.addColumn("c_nvarchar_50", "nvarchar(50)", asList("'India'", "'France'", null))
.addColumn("c_nvarchar_65535", "nvarchar(65535)", asList("'India'", "'France'", null)) // Equal to REDSHIFT_MAX_VARCHAR
Expand Down Expand Up @@ -243,6 +262,17 @@ public void testAllJoinPushdownWithCast()
// Full Join pushdown is not supported
assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%s AS %s) = r.%s".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn())))
.joinIsNotFullyPushedDown();

testCase = new CastTestCase("c_varchar_10", "varchar(200)", "c_varchar_50");
assertThat(query("SELECT l.id FROM %s l LEFT JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn())))
.isFullyPushedDown();
assertThat(query("SELECT l.id FROM %s l RIGHT JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn())))
.isFullyPushedDown();
assertThat(query("SELECT l.id FROM %s l INNER JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn())))
.isFullyPushedDown();
// Full Join pushdown is not supported
assertThat(query("SELECT l.id FROM %s l FULL JOIN %s r ON CAST(l.%3$s AS %4$s) = CAST(r.%5$s AS %4$s)".formatted(leftTable(), rightTable(), testCase.sourceColumn(), testCase.castType(), testCase.targetColumn())))
.joinIsNotFullyPushedDown();
}

@Test
Expand Down Expand Up @@ -364,6 +394,16 @@ public void testCastPushdownWithForcedTypedToInteger()
.isNotFullyPushedDown(ProjectNode.class);
}

@Test
void testCastPushdownWithForcedTypedToVarchar()
{
    // The timetz and super Redshift types are not supported by Trino by default and are
    // force-mapped to varchar, so a cast on top of them cannot be pushed down.
    String table = leftTable();
    assertThat(query("SELECT CAST(c_timetz AS VARCHAR(100)) FROM %s".formatted(table)))
            .isNotFullyPushedDown(ProjectNode.class);
    assertThat(query("SELECT CAST(c_super AS VARCHAR(100)) FROM %s".formatted(table)))
            .isNotFullyPushedDown(ProjectNode.class);
}

@Override
protected List<CastTestCase> supportedCastTypePushdown()
{
Expand Down Expand Up @@ -393,6 +433,41 @@ protected List<CastTestCase> supportedCastTypePushdown()
new CastTestCase("c_numeric_30_2", "integer", "c_integer"),
new CastTestCase("c_decimal_negative", "integer", "c_integer"),

new CastTestCase("c_char_10", "char(50)", "c_char_50"),
new CastTestCase("c_char_10", "char(256)", "c_bpchar"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_nvarchar_10", "varchar(50)", "c_nvarchar_50"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_text"),

new CastTestCase("c_char_50", "char(10)", "c_char_10"),
new CastTestCase("c_bpchar", "char(10)", "c_char_10"),
new CastTestCase("c_varchar_50", "varchar(10)", "c_varchar_10"),
new CastTestCase("c_nvarchar_50", "varchar(10)", "c_nvarchar_10"),
new CastTestCase("c_text", "varchar(10)", "c_varchar_10"),

new CastTestCase("c_char_10", "char(50)", "c_char_50"),
new CastTestCase("c_char_10", "char(256)", "c_bpchar"),
new CastTestCase("c_char", "char(4096)", "c_char_4096"),
new CastTestCase("c_char", "char(1)", "c_nchar"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_nvarchar_10", "varchar(50)", "c_nvarchar_50"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_text"),

new CastTestCase("c_varchar_50", "varchar(10)", "c_varchar_10"),

new CastTestCase("c_varchar_10", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_varchar_10", "varchar(65535)", "c_varchar_65535"),
new CastTestCase("c_varchar_10", "varchar(256)", "c_nvarchar"),
new CastTestCase("c_varchar_10", "varchar(10)", "c_nvarchar_10"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_nvarchar_50"),
new CastTestCase("c_varchar_10", "varchar(65535)", "c_nvarchar_65535"),
new CastTestCase("c_varchar_10", "varchar(256)", "c_text"),
new CastTestCase("c_varchar_10", "varchar", "c_text"),

new CastTestCase("c_varchar_unicode", "varchar", "c_varchar_50"),
new CastTestCase("c_varchar_unicode", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_nvarchar_unicode", "varchar(50)", "c_varchar_50"),

new CastTestCase("c_boolean", "bigint", "c_bigint"),
new CastTestCase("c_smallint", "bigint", "c_bigint"),
new CastTestCase("c_integer", "bigint", "c_bigint"),
Expand Down Expand Up @@ -460,11 +535,30 @@ protected List<CastTestCase> unsupportedCastTypePushdown()
new CastTestCase("c_real", "double", "c_double_precision"),
new CastTestCase("c_double_precision", "real", "c_real"),
new CastTestCase("c_double_precision", "decimal(10,2)", "c_decimal_10_2"),
new CastTestCase("c_char_10", "char(50)", "c_char_50"),
new CastTestCase("c_char_10", "char(256)", "c_bpchar"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_nvarchar_10", "varchar(50)", "c_nvarchar_50"),
new CastTestCase("c_varchar_10", "varchar(50)", "c_text"),

new CastTestCase("c_varchar_50", "char(50)", "c_char_50"),
new CastTestCase("c_char_50", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_boolean", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_smallint", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_int2", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_integer", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_int", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_int4", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_bigint", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_int8", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_real", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_float4", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_double_precision", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_float", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_float8", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_double_precision", "varchar(50)", "c_varchar_50"),

new CastTestCase("c_timestamp", "varchar(50)", "c_varchar_50"),
new CastTestCase("c_date", "varchar(50)", "c_varchar_50"),

new CastTestCase("c_varchar_unicode", "char(50)", "c_char_50"),
new CastTestCase("c_nvarchar_unicode", "char(50)", "c_char_50"),

new CastTestCase("c_timestamp", "date", "c_date"),
new CastTestCase("c_timestamp", "time", "c_time"),
new CastTestCase("c_date", "timestamp", "c_timestamp"),
Expand All @@ -491,4 +585,19 @@ protected List<InvalidCastTestCase> invalidCast()
new InvalidCastTestCase("c_timetz", "int"),
new InvalidCastTestCase("c_timetz", "bigint"));
}

// Verifies pushdown behavior for a char column wide enough that Redshift converts it to varchar:
// a cast of that column to varchar is pushed down, while a cast to char is not.
@Test
void testCastPushdownWithCharConvertedToVarchar()
{
try (TestTable table = new TestTable(
getQueryRunner()::execute,
TEST_SCHEMA + "." + "case_sensitive_1_",
"(a char(4097))", // char(REDSHIFT_MAX_CHAR + 1) is converted to varchar(REDSHIFT_MAX_CHAR + 1) in Redshift
ImmutableList.of("'hello'"))) {
assertThat(query("SELECT cast(a AS varchar(50)) FROM " + table.getName()))
.isFullyPushedDown();
assertThat(query("SELECT cast(a AS char(50)) FROM " + table.getName()))
.isNotFullyPushedDown(ProjectNode.class);
}
}
}
Loading
Loading