Skip to content

Commit

Permalink
Add support for changing column types in MongoDB connector
Browse files Browse the repository at this point in the history
  • Loading branch information
ebyhr committed Jan 30, 2023
1 parent 8d18ca2 commit d004525
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Streams;
import com.google.common.io.Closer;
import com.mongodb.client.MongoCollection;
import io.airlift.log.Logger;
Expand Down Expand Up @@ -51,8 +52,14 @@
import io.trino.spi.ptf.ConnectorTableFunctionHandle;
import io.trino.spi.security.TrinoPrincipal;
import io.trino.spi.statistics.ComputedStatistics;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.RowType.Field;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import org.bson.Document;

import java.io.IOException;
Expand All @@ -71,6 +78,7 @@
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static com.mongodb.client.model.Aggregates.lookup;
import static com.mongodb.client.model.Aggregates.match;
import static com.mongodb.client.model.Aggregates.merge;
Expand All @@ -81,6 +89,11 @@
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static java.lang.Math.toIntExact;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -254,6 +267,80 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandl
mongoSession.dropColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) column).getName());
}

@Override
public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle, Type type)
{
MongoTableHandle table = (MongoTableHandle) tableHandle;
MongoColumnHandle column = (MongoColumnHandle) columnHandle;
if (!canChangeColumnType(column.getType(), type)) {
throw new TrinoException(NOT_SUPPORTED, "Cannot change type from %s to %s".formatted(column.getType(), type));
}
mongoSession.setColumnType(table, column.getName(), type);
}

private static boolean canChangeColumnType(Type sourceType, Type newType)
{
if (sourceType.equals(newType)) {
return true;
}
if (sourceType == TINYINT) {
return newType == SMALLINT || newType == INTEGER || newType == BIGINT;
}
if (sourceType == SMALLINT) {
return newType == INTEGER || newType == BIGINT;
}
if (sourceType == INTEGER) {
return newType == BIGINT;
}
if (sourceType == REAL) {
return newType == DOUBLE;
}
if (sourceType instanceof VarcharType || sourceType instanceof CharType) {
return newType instanceof VarcharType || newType instanceof CharType;
}
if (sourceType instanceof DecimalType sourceDecimal && newType instanceof DecimalType newDecimal) {
return sourceDecimal.getScale() == newDecimal.getScale()
&& sourceDecimal.getPrecision() <= newDecimal.getPrecision();
}
if (sourceType instanceof ArrayType sourceArrayType && newType instanceof ArrayType newArrayType) {
return canChangeColumnType(sourceArrayType.getElementType(), newArrayType.getElementType());
}
if (sourceType instanceof RowType sourceRowType && newType instanceof RowType newRowType) {
List<Field> fields = Streams.concat(sourceRowType.getFields().stream(), newRowType.getFields().stream())
.distinct()
.collect(toImmutableList());
for (Field field : fields) {
String fieldName = field.getName().orElseThrow();
if (fieldExists(sourceRowType, fieldName) && fieldExists(newRowType, fieldName)) {
if (!canChangeColumnType(
findFieldByName(sourceRowType.getFields(), fieldName).getType(),
findFieldByName(newRowType.getFields(), fieldName).getType())) {
return false;
}
}
}
return true;
}
return false;
}

private static Field findFieldByName(List<Field> fields, String fieldName)
{
return fields.stream()
.filter(field -> field.getName().orElseThrow().equals(fieldName))
.collect(onlyElement());
}

private static boolean fieldExists(RowType structType, String fieldName)
{
for (Field field : structType.getFields()) {
if (field.getName().orElseThrow().equals(fieldName)) {
return true;
}
}
return false;
}

@Override
public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional<ConnectorTableLayout> layout, RetryMode retryMode)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,31 @@ public void dropColumn(MongoTableHandle table, String columnName)
tableCache.invalidate(table.getSchemaTableName());
}

public void setColumnType(MongoTableHandle table, String columnName, Type type)
{
String remoteSchemaName = table.getRemoteTableName().getDatabaseName();
String remoteTableName = table.getRemoteTableName().getCollectionName();

Document metadata = getTableMetadata(remoteSchemaName, remoteTableName);

List<Document> columns = getColumnMetadata(metadata).stream()
.map(document -> {
if (document.getString(FIELDS_NAME_KEY).equals(columnName)) {
document.put(FIELDS_TYPE_KEY, type.getTypeSignature().toString());
return document;
}
return document;
})
.collect(toImmutableList());

metadata.replace(FIELDS_KEY, columns);

client.getDatabase(remoteSchemaName).getCollection(schemaCollection)
.findOneAndReplace(new Document(TABLE_NAME_KEY, remoteTableName), metadata);

tableCache.invalidate(table.getSchemaTableName());
}

private MongoTable loadTableSchema(SchemaTableName schemaTableName)
throws TableNotFoundException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
return false;

case SUPPORTS_RENAME_COLUMN:
case SUPPORTS_SET_COLUMN_TYPE:
return false;

case SUPPORTS_NOT_NULL_CONSTRAINT:
Expand Down Expand Up @@ -810,6 +809,29 @@ protected void verifyTableNameLengthFailurePermissible(Throwable e)
assertThat(e).hasMessageMatching(".*fully qualified namespace .* is too long.*|Qualified identifier name must be shorter than or equal to '120'.*");
}

@Override
protected void verifySetColumnTypeFailurePermissible(Throwable e)
{
assertThat(e).hasMessageContaining("Cannot change type");
}

@Override
protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColumnTypeSetup setup)
{
switch ("%s -> %s".formatted(setup.sourceColumnType(), setup.newColumnType())) {
case "bigint -> integer":
case "decimal(5,3) -> decimal(5,2)":
case "time(3) -> time(6)":
case "time(6) -> time(3)":
case "timestamp(3) -> timestamp(6)":
case "timestamp(6) -> timestamp(3)":
case "timestamp(3) with time zone -> timestamp(6) with time zone":
case "timestamp(6) with time zone -> timestamp(3) with time zone":
return Optional.of(setup.asUnsupported());
}
return Optional.of(setup);
}

private void assertOneNotNullResult(String query)
{
MaterializedResult results = getQueryRunner().execute(getSession(), query).toTestTypes();
Expand Down

0 comments on commit d004525

Please sign in to comment.