diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java index 6c37ed076dbcf..993eb93b83e08 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoMetadata.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; import com.google.common.io.Closer; import com.mongodb.client.MongoCollection; import io.airlift.log.Logger; @@ -51,8 +52,14 @@ import io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.security.TrinoPrincipal; import io.trino.spi.statistics.ComputedStatistics; +import io.trino.spi.type.ArrayType; import io.trino.spi.type.BigintType; +import io.trino.spi.type.CharType; +import io.trino.spi.type.DecimalType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.RowType.Field; import io.trino.spi.type.Type; +import io.trino.spi.type.VarcharType; import org.bson.Document; import java.io.IOException; @@ -71,6 +78,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.MoreCollectors.onlyElement; import static com.mongodb.client.model.Aggregates.lookup; import static com.mongodb.client.model.Aggregates.match; import static com.mongodb.client.model.Aggregates.merge; @@ -81,6 +89,11 @@ import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TinyintType.TINYINT; import static java.lang.Math.toIntExact; import static java.lang.String.format; import static java.nio.charset.StandardCharsets.UTF_8; @@ -254,6 +267,80 @@ public void dropColumn(ConnectorSession session, ConnectorTableHandle tableHandl mongoSession.dropColumn(((MongoTableHandle) tableHandle), ((MongoColumnHandle) column).getName()); } + @Override + public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle, Type type) + { + MongoTableHandle table = (MongoTableHandle) tableHandle; + MongoColumnHandle column = (MongoColumnHandle) columnHandle; + if (!canChangeColumnType(column.getType(), type)) { + throw new TrinoException(NOT_SUPPORTED, "Cannot change type from %s to %s".formatted(column.getType(), type)); + } + mongoSession.setColumnType(table, column.getName(), type); + } + + private static boolean canChangeColumnType(Type sourceType, Type newType) + { + if (sourceType.equals(newType)) { + return true; + } + if (sourceType == TINYINT) { + return newType == SMALLINT || newType == INTEGER || newType == BIGINT; + } + if (sourceType == SMALLINT) { + return newType == INTEGER || newType == BIGINT; + } + if (sourceType == INTEGER) { + return newType == BIGINT; + } + if (sourceType == REAL) { + return newType == DOUBLE; + } + if (sourceType instanceof VarcharType || sourceType instanceof CharType) { + return newType instanceof VarcharType || newType instanceof CharType; + } + if (sourceType instanceof DecimalType sourceDecimal && newType instanceof DecimalType newDecimal) { + return sourceDecimal.getScale() == newDecimal.getScale() + && sourceDecimal.getPrecision() <= newDecimal.getPrecision(); + } + if (sourceType instanceof ArrayType sourceArrayType && newType instanceof ArrayType newArrayType) { + return canChangeColumnType(sourceArrayType.getElementType(), newArrayType.getElementType()); + } + if (sourceType instanceof RowType sourceRowType && newType instanceof RowType newRowType) { + List fields = Streams.concat(sourceRowType.getFields().stream(), newRowType.getFields().stream()) + .distinct() + .collect(toImmutableList()); + for (Field field : fields) { + String fieldName = field.getName().orElseThrow(); + if (fieldExists(sourceRowType, fieldName) && fieldExists(newRowType, fieldName)) { + if (!canChangeColumnType( + findFieldByName(sourceRowType.getFields(), fieldName).getType(), + findFieldByName(newRowType.getFields(), fieldName).getType())) { + return false; + } + } + } + return true; + } + return false; + } + + private static Field findFieldByName(List fields, String fieldName) + { + return fields.stream() + .filter(field -> field.getName().orElseThrow().equals(fieldName)) + .collect(onlyElement()); + } + + private static boolean fieldExists(RowType structType, String fieldName) + { + for (Field field : structType.getFields()) { + if (field.getName().orElseThrow().equals(fieldName)) { + return true; + } + } + return false; + } + @Override public ConnectorOutputTableHandle beginCreateTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, Optional layout, RetryMode retryMode) { diff --git a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java index b646a5f89e743..4dfa0a1aee0e3 100644 --- a/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java +++ b/plugin/trino-mongodb/src/main/java/io/trino/plugin/mongodb/MongoSession.java @@ -353,6 +353,31 @@ public void dropColumn(MongoTableHandle table, String columnName) tableCache.invalidate(table.getSchemaTableName()); } + public void setColumnType(MongoTableHandle table, String columnName, Type type) + { + String remoteSchemaName = table.getRemoteTableName().getDatabaseName(); + String remoteTableName = table.getRemoteTableName().getCollectionName(); + + Document metadata = getTableMetadata(remoteSchemaName, remoteTableName); + + List columns = getColumnMetadata(metadata).stream() + .map(document -> { + if (document.getString(FIELDS_NAME_KEY).equals(columnName)) { + document.put(FIELDS_TYPE_KEY, type.getTypeSignature().toString()); + return document; + } + return document; + }) + .collect(toImmutableList()); + + metadata.replace(FIELDS_KEY, columns); + + client.getDatabase(remoteSchemaName).getCollection(schemaCollection) + .findOneAndReplace(new Document(TABLE_NAME_KEY, remoteTableName), metadata); + + tableCache.invalidate(table.getSchemaTableName()); + } + private MongoTable loadTableSchema(SchemaTableName schemaTableName) throws TableNotFoundException { diff --git a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java index 29c788bcb0613..be993daffbeca 100644 --- a/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java +++ b/plugin/trino-mongodb/src/test/java/io/trino/plugin/mongodb/TestMongoConnectorTest.java @@ -90,7 +90,6 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) return false; case SUPPORTS_RENAME_COLUMN: - case SUPPORTS_SET_COLUMN_TYPE: return false; case SUPPORTS_NOT_NULL_CONSTRAINT: @@ -810,6 +809,29 @@ protected void verifyTableNameLengthFailurePermissible(Throwable e) assertThat(e).hasMessageMatching(".*fully qualified namespace .* is too long.*|Qualified identifier name must be shorter than or equal to '120'.*"); } + @Override + protected void verifySetColumnTypeFailurePermissible(Throwable e) + { + assertThat(e).hasMessageContaining("Cannot change type"); + } + + @Override + protected Optional filterSetColumnTypesDataProvider(SetColumnTypeSetup setup) + { + switch ("%s -> %s".formatted(setup.sourceColumnType(), setup.newColumnType())) { + case "bigint -> integer": + case "decimal(5,3) -> decimal(5,2)": + case "time(3) -> time(6)": + case "time(6) -> time(3)": + case "timestamp(3) -> timestamp(6)": + case "timestamp(6) -> timestamp(3)": + case "timestamp(3) with time zone -> timestamp(6) with time zone": + case "timestamp(6) with time zone -> timestamp(3) with time zone": + return Optional.of(setup.asUnsupported()); + } + return Optional.of(setup); + } + private void assertOneNotNullResult(String query) { MaterializedResult results = getQueryRunner().execute(getSession(), query).toTestTypes();